Stergios-Konstantinidis committed on
Commit
ddd21b8
·
verified ·
1 Parent(s): c36eeab

Upload 2 files

Browse files
Files changed (2) hide show
  1. rag bm25 pipeline.py +52 -0
  2. requirements.txt +3 -0
rag bm25 pipeline.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # bm25_retriever.py
2
+ from rank_bm25 import BM25Okapi
3
+ import json
4
+ from huggingface_hub import hf_hub_download
5
+ from datasets import load_dataset
6
+
7
class BM25Retriever:
    """BM25 lexical retriever over a corpus stored as a Hugging Face dataset.

    On construction the ``train`` split of the dataset is loaded, the ``text``
    column of every row is kept in ``self.documents`` and a ``BM25Okapi``
    index is built over the tokenized documents.
    """

    def __init__(self, corpus_dataset_id: str):
        """
        Initializes the BM25 retriever.

        Args:
            corpus_dataset_id: Hugging Face dataset ID where your corpus is stored.
        """
        # Honour the caller's argument; it used to be silently overridden by
        # a hard-coded dataset ID.
        self.corpus_dataset_id = corpus_dataset_id
        self.bm25 = None
        self.documents = []

        self._load_corpus_and_build_bm25()

    @staticmethod
    def _tokenize(text: str) -> list:
        """Lower-case *text* and split it on any run of whitespace.

        The same tokenizer must be used for the corpus and for queries:
        BM25 matches tokens exactly, so a case or whitespace mismatch between
        the two sides makes terms unmatchable.
        """
        return text.lower().split()

    def _load_corpus_and_build_bm25(self):
        """Load the corpus from the Hub and build the BM25 index over it."""
        print(f"Loading corpus from dataset: {self.corpus_dataset_id}")
        dataset = load_dataset(self.corpus_dataset_id, split="train")

        # Only the 'text' column of each row is used.
        self.documents = [doc['text'] for doc in dataset]

        tokenized_corpus = [self._tokenize(doc) for doc in self.documents]
        self.bm25 = BM25Okapi(tokenized_corpus)
        print("BM25 index built successfully.")

    def retrieve(self, query: str, top_k: int = 5):
        """Return the ``top_k`` highest-scoring documents for *query*.

        Args:
            query: Free-text query; tokenized the same way as the corpus.
            top_k: Maximum number of documents to return. Values <= 0 yield
                an empty list.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts ordered by
            descending score. It holds at most ``top_k`` entries (fewer if the
            corpus is smaller).
        """
        # Guard first: with top_k == 0 the slice [-0:] below would select the
        # *entire* corpus, and negative values would select arbitrary ranges.
        if top_k <= 0:
            return []

        doc_scores = self.bm25.get_scores(self._tokenize(query))

        # argsort is ascending: take the last top_k indices and reverse them
        # so the best match comes first.
        top_indices = doc_scores.argsort()[-top_k:][::-1]

        return [
            {
                "text": self.documents[i],
                # Plain float so results are JSON-serialisable.
                "score": float(doc_scores[i]),
            }
            for i in top_indices
        ]


if __name__ == "__main__":
    # Example usage (replace with your actual dataset ID)
    retriever = BM25Retriever("Stergios-Konstantinidis/MNLP_M2_rag_dataset")
    results = retriever.retrieve("What is a fox?", top_k=2)
    for r in results:
        print(f"Score: {r['score']}, Text: {r['text']}")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ datasets
2
+ huggingface_hub
3
+ rank_bm25