# Source: MNLP_M2_rag_model / bm25_retriever.py
# Author: Stergios-Konstantinidis
# Commit: 58aa838 ("Upload 3 files", verified)
# bm25_retriever.py
import json
import re

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from rank_bm25 import BM25Okapi
class BM25Retriever:
    """Lexical (BM25 Okapi) retriever over a text corpus stored as a Hugging Face dataset.

    The whole corpus split is downloaded and indexed in memory when the
    retriever is constructed; ``retrieve`` then scores every document against
    a query and returns the best matches.
    """

    # Corpus indexed when the caller does not name a dataset explicitly.
    DEFAULT_CORPUS_DATASET_ID = "Stergios-Konstantinidis/MNLP_M2_rag_dataset"

    def __init__(
        self,
        corpus_dataset_id: str = DEFAULT_CORPUS_DATASET_ID,
        *,
        split: str = "train",
        text_column: str = "text",
    ):
        """
        Initializes the BM25 retriever.

        Args:
            corpus_dataset_id: Hugging Face dataset ID where the corpus is stored.
            split: Dataset split to index.
            text_column: Name of the column holding each document's text.

        Raises:
            ValueError: If the selected split contains no documents.
        """
        # Honour the caller's dataset ID (it used to be silently overwritten).
        self.corpus_dataset_id = corpus_dataset_id
        self.split = split
        self.text_column = text_column
        self.bm25 = None
        self.documents = []
        self._load_corpus_and_build_bm25()

    @staticmethod
    def _tokenize(text: str):
        """Lower-case ``text`` and split it into word tokens (runs of ``\\w`` characters).

        The corpus and queries share this tokenizer so their terms are
        comparable. Punctuation is dropped, so "fox?" matches "fox", and
        repeated whitespace never yields empty tokens.
        """
        return re.findall(r"\w+", text.lower())

    def _load_corpus_and_build_bm25(self):
        """Download the configured corpus split and build the BM25 index over it.

        Raises:
            ValueError: If the split contains no documents. BM25 needs a
                non-empty corpus because its average document length is
                undefined otherwise.
        """
        print(f"Loading corpus from dataset: {self.corpus_dataset_id}")
        dataset = load_dataset(self.corpus_dataset_id, split=self.split)
        self.documents = list(dataset[self.text_column])
        if not self.documents:
            raise ValueError(
                f"Dataset {self.corpus_dataset_id!r} (split {self.split!r}) "
                "contains no documents to index"
            )
        tokenized_corpus = [self._tokenize(doc) for doc in self.documents]
        self.bm25 = BM25Okapi(tokenized_corpus)
        print("BM25 index built successfully.")

    def retrieve(self, query: str, top_k: int = 5):
        """Return the ``top_k`` documents that best match ``query``.

        Args:
            query: Free-text query, tokenized exactly like the corpus.
            top_k: Maximum number of results. Values <= 0 yield an empty list.
                Values larger than the corpus return every document.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts, best match first.
        """
        # Explicit guard: with top_k == 0 the slice [-0:] below would select
        # the entire corpus instead of nothing.
        if top_k <= 0:
            return []
        doc_scores = self.bm25.get_scores(self._tokenize(query))
        # argsort is ascending, so the k best are at the end; reverse them so
        # the highest score comes first.
        top_indices = doc_scores.argsort()[-top_k:][::-1]
        return [
            {"text": self.documents[i], "score": float(doc_scores[i])}
            for i in top_indices
        ]


if __name__ == "__main__":
    # Example usage: index the default corpus dataset and run a sample query.
    retriever = BM25Retriever()
    results = retriever.retrieve("What is a fox?", top_k=2)
    for r in results:
        print(f"Score: {r['score']}, Text: {r['text']}")