j-js committed on
Commit
7090817
·
verified ·
1 Parent(s): b3d298d

Create retrieval_engine.py

Browse files
Files changed (1) hide show
  1. retrieval_engine.py +34 -0
retrieval_engine.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from sentence_transformers import SentenceTransformer
3
+ import faiss
4
+ import numpy as np
5
+
6
class RetrievalEngine:
    """Nearest-neighbour text retrieval over a Hugging Face dataset.

    On construction the engine loads a dataset, embeds one text column with a
    SentenceTransformer model and stores the vectors in an exact (flat) FAISS
    L2 index. ``search`` embeds a query with the same model and returns the
    texts whose embeddings are closest to it.

    Attributes set by ``__init__``:
        dataset: the object returned by ``load_dataset``.
        model: the ``SentenceTransformer`` used for corpus and queries.
        index: the ``faiss.IndexFlatL2`` holding the corpus embeddings.
        texts: list of corpus texts, positionally aligned with ``index``.
    """

    def __init__(
        self,
        dataset_name="YOUR_USERNAME/gmat-quant-corpus",
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        split="train",
        text_column="text",
    ):
        """Load the corpus, embed it and build the FAISS index.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
                NOTE(review): the default looks like an unfilled placeholder
                ("YOUR_USERNAME/...") -- confirm the real repository id.
            model_name: Model identifier passed to ``SentenceTransformer``.
            split: Dataset split to load.
            text_column: Name of the row field that holds the text to index.

        Raises:
            ValueError: If the loaded split contains no rows, since an index
                cannot be dimensioned without at least one embedding.
        """
        self.dataset = load_dataset(dataset_name, split=split)
        self.model = SentenceTransformer(model_name)

        texts = [row[text_column] for row in self.dataset]
        if not texts:
            raise ValueError(
                f"Dataset {dataset_name!r} (split {split!r}) contains no rows to index"
            )

        embeddings = self.model.encode(texts, convert_to_numpy=True)
        # FAISS operates on C-contiguous float32 matrices; coerce defensively
        # in case the encoder returns another dtype or memory layout.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(embeddings)

        # Kept alongside the index so a FAISS row id maps back to its text.
        self.texts = texts

    def search(self, query, k=3):
        """Return up to ``k`` corpus texts closest to ``query``.

        Args:
            query: The query string to embed and look up.
            k: Maximum number of results. Values larger than the corpus size
                are clamped; values <= 0 yield an empty list.

        Returns:
            A list of texts in the order the index returned them. It may hold
            fewer than ``k`` items when the corpus is smaller than ``k``.
        """
        if k <= 0 or not self.texts:
            return []

        # Never ask for more neighbours than exist: FAISS pads missing
        # results with the label -1.
        k = min(k, len(self.texts))

        q_emb = self.model.encode([query], convert_to_numpy=True)
        q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)

        _distances, indices = self.index.search(q_emb, k)

        # Drop -1 padding labels; used as a Python list index they would
        # silently alias the *last* text instead of meaning "no result".
        return [self.texts[idx] for idx in indices[0] if idx >= 0]