from typing import List, Tuple, Dict, Literal

import numpy as np
from rank_bm25 import BM25Okapi

from ..base import BaseRetriever
from ..utils import tokenise


class BM25Retriever(BaseRetriever):
    """BM25 sparse retrieval for AgentBase."""

    def __init__(self, db_path: str, index_config: Literal["naive", "v1"], **bm25_params):
        """
        Create the retriever and build its BM25 index immediately.

        Args:
            db_path: Forwarded unchanged to ``BaseRetriever.__init__``.
            index_config: Key used to pick the document loader from
                ``self.indexing_func`` (presumably populated by
                ``BaseRetriever`` -- it is not set in this class).
            **bm25_params: Extra keyword arguments forwarded verbatim to
                ``BM25Okapi`` (e.g. ``k1``, ``b``, ``epsilon``).
        """
        super().__init__(db_path, index_config)
        self.bm25_params = bm25_params
        self.index = None
        self.build_index()

    def build_index(self) -> None:
        """
        Load documents and build the BM25 index.

        Calls the loader selected by ``self.index_config`` from
        ``self.indexing_func``. The loader must return ``(agent_ids, documents)``
        as index-aligned sequences, because ``retrieve`` maps document positions
        back to ``agent_ids``. Each document is tokenised with ``tokenise``
        before being handed to ``BM25Okapi``.
        """
        self.agent_ids, self.documents = self.indexing_func[self.index_config]()
        tokenised_docs = [tokenise(doc) for doc in self.documents]
        self.index = BM25Okapi(tokenised_docs, **self.bm25_params)

    def retrieve(self, query: str, top_k: int = 10):
        """
        Retrieve the top-k agents for ``query`` using BM25.

        Args:
            query: Free-text query, tokenised with the same ``tokenise`` used
                for the indexed documents.
            top_k: Maximum number of results. Values <= 0 yield an empty list;
                values larger than the corpus return every document.

        Returns:
            A list of ``(agent_id, score)`` tuples ordered by descending BM25
            score, with ``score`` as a Python ``float``. Documents with equal
            scores keep their corpus order (earlier documents first).
        """
        if top_k <= 0:
            # Guard explicitly: ``argsort(...)[-0:]`` is ``[0:]`` and would
            # return the whole corpus instead of nothing; a negative top_k
            # would likewise return almost everything.
            return []

        tokenised_query = tokenise(query)
        scores = np.asarray(self.index.get_scores(tokenised_query))

        # Negate to sort descending; a stable sort keeps tied documents in
        # corpus order so results are deterministic across runs/platforms.
        top_indices = np.argsort(-scores, kind="stable")[:top_k]
        return [(self.agent_ids[idx], float(scores[idx])) for idx in top_indices]