| |
import json
import re

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from rank_bm25 import BM25Okapi
|
|
class BM25Retriever:
    """Lexical (BM25 Okapi) retriever over a text corpus loaded with `datasets.load_dataset`."""

    # Corpus used when no dataset ID is given (previously this was always used,
    # regardless of the constructor argument).
    DEFAULT_CORPUS_DATASET_ID = "Stergios-Konstantinidis/MNLP_M2_rag_dataset"

    # Runs of word characters; dropping punctuation lets "fox?" in a query
    # match "fox" in a document.
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(
        self,
        corpus_dataset_id: str = DEFAULT_CORPUS_DATASET_ID,
        *,
        split: str = "train",
        text_field: str = "text",
    ):
        """
        Initializes the BM25 retriever and builds the index immediately.

        Args:
            corpus_dataset_id: Hugging Face dataset ID where your corpus is stored.
            split: Dataset split to index.
            text_field: Name of the column holding each document's text.

        Raises:
            ValueError: If the selected split contains no documents.
        """
        self.corpus_dataset_id = corpus_dataset_id
        self.split = split
        self.text_field = text_field
        self.bm25 = None
        self.documents = []

        self._load_corpus_and_build_bm25()

    @classmethod
    def _tokenize(cls, text: str) -> list:
        """Lowercase *text* and split it into word tokens, discarding punctuation and whitespace."""
        return cls._TOKEN_RE.findall(text.lower())

    @staticmethod
    def _top_k_indices(scores, top_k: int) -> list:
        """Return the indices of the *top_k* highest *scores*, best first; empty if top_k <= 0."""
        if top_k <= 0:
            # Guard explicitly: scores.argsort()[-0:] would select every index.
            return []
        return scores.argsort()[::-1][:top_k].tolist()

    def _load_corpus_and_build_bm25(self):
        """Load the corpus texts into self.documents and build the BM25Okapi index over them."""
        print(f"Loading corpus from dataset: {self.corpus_dataset_id}")
        dataset = load_dataset(self.corpus_dataset_id, split=self.split)

        self.documents = [row[self.text_field] for row in dataset]
        if not self.documents:
            # Fail fast with a clear message instead of an obscure error from the index.
            raise ValueError(
                f"Dataset {self.corpus_dataset_id!r} (split {self.split!r}) contains no documents."
            )

        # Index and queries go through the same tokenizer so their terms are comparable.
        tokenized_corpus = [self._tokenize(doc) for doc in self.documents]
        self.bm25 = BM25Okapi(tokenized_corpus)
        print("BM25 index built successfully.")

    def retrieve(self, query: str, top_k: int = 5):
        """
        Return the *top_k* documents that best match *query*.

        Returns:
            A list of {"text": ..., "score": float} dicts ordered by descending
            BM25 score. The list is empty when top_k <= 0 and shorter than
            top_k when the corpus has fewer documents.
        """
        doc_scores = self.bm25.get_scores(self._tokenize(query))
        return [
            {"text": self.documents[i], "score": float(doc_scores[i])}
            for i in self._top_k_indices(doc_scores, top_k)
        ]


if __name__ == "__main__":
    # Smoke test against the corpus dataset (not the model repo, which is not a dataset).
    retriever = BM25Retriever(BM25Retriever.DEFAULT_CORPUS_DATASET_ID)
    results = retriever.retrieve("What is a fox?", top_k=2)
    for r in results:
        print(f"Score: {r['score']}, Text: {r['text']}")