Spaces:
Running
Running
| from typing import List, Tuple, Dict, Literal | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| from ..base import BaseRetriever | |
| from ..utils import tokenise | |
class BM25Retriever(BaseRetriever):
    """BM25 sparse retrieval for AgentBase.

    The index is built eagerly in ``__init__``: documents are loaded via the
    indexing function selected by ``index_config``, tokenised with
    ``tokenise`` and handed to ``rank_bm25.BM25Okapi``.
    """

    def __init__(self, db_path: str, index_config: Literal["naive", "v1"], **bm25_params):
        """
        Set up the retriever and build its BM25 index immediately.

        Args:
            db_path: Forwarded unchanged to ``BaseRetriever.__init__``.
            index_config: Key used to pick the indexing function in
                ``build_index``; also forwarded to ``BaseRetriever.__init__``.
            **bm25_params: Extra keyword arguments passed verbatim to the
                ``BM25Okapi`` constructor.
        """
        super().__init__(db_path, index_config)
        self.bm25_params = bm25_params
        self.index = None
        self.build_index()

    def build_index(self):
        """
        Load documents and build BM25 index.

        Populates ``self.agent_ids``, ``self.documents`` and ``self.index``.
        """
        # NOTE(review): ``indexing_func`` and ``index_config`` are not set in
        # this class -- presumably provided by BaseRetriever; confirm there.
        loader = self.indexing_func[self.index_config]
        self.agent_ids, self.documents = loader()
        tokenised_docs = [tokenise(doc) for doc in self.documents]
        self.index = BM25Okapi(tokenised_docs, **self.bm25_params)

    def retrieve(self, query: str, top_k: int = 10) -> List[Tuple]:
        """
        Retrieve top-k agents using BM25.

        Args:
            query: Raw query string; tokenised with ``tokenise`` before scoring.
            top_k: Maximum number of results to return. Values less than or
                equal to zero yield an empty list.

        Returns:
            A list of ``(agent_id, score)`` pairs ordered from highest to
            lowest score, containing at most ``top_k`` entries.
        """
        # Guard first: ``arr[-0:]`` is ``arr[0:]`` (the whole array) and a
        # negative top_k would turn into a positive start offset, so without
        # this check a non-positive top_k would return unrelated results.
        if top_k <= 0:
            return []

        tokenized_query = tokenise(query)
        scores = self.index.get_scores(tokenized_query)
        # argsort is ascending: keep the last top_k positions, then reverse so
        # the best-scoring document comes first.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        return [(self.agent_ids[idx], float(scores[idx])) for idx in top_indices]