ayush2917 commited on
Commit
c06718f
·
verified ·
1 Parent(s): f808bd5

Update src/retrieval.py

Browse files
Files changed (1) hide show
  1. src/retrieval.py +8 -7
src/retrieval.py CHANGED
@@ -1,22 +1,23 @@
1
- from sentence_transformers import SentenceTransformer
2
  import numpy as np
 
3
  import json
 
4
 
5
  class DocumentRetriever:
6
  def __init__(self, model_name='all-MiniLM-L6-v2'):
7
  self.model = SentenceTransformer(model_name)
8
  self.documents = self._load_documents()
9
  self.doc_embeddings = self._embed_documents()
10
-
11
- def _load_documents(self):
12
  with open('data/rupeia_document.json', 'r') as f:
13
  return json.load(f)
14
-
15
- def _embed_documents(self):
16
  texts = [doc['content'] for doc in self.documents]
17
  return self.model.encode(texts)
18
-
19
- def retrieve(self, query: str, top_k: int = 3):
20
  query_embedding = self.model.encode(query)
21
  scores = np.dot(self.doc_embeddings, query_embedding)
22
  top_indices = np.argsort(scores)[-top_k:][::-1]
 
 
1
  import numpy as np
2
+ from sentence_transformers import SentenceTransformer
3
  import json
4
+ from typing import List, Dict
5
 
6
  class DocumentRetriever:
7
  def __init__(self, model_name='all-MiniLM-L6-v2'):
8
  self.model = SentenceTransformer(model_name)
9
  self.documents = self._load_documents()
10
  self.doc_embeddings = self._embed_documents()
11
+
12
+ def _load_documents(self) -> List[Dict]:
13
  with open('data/rupeia_document.json', 'r') as f:
14
  return json.load(f)
15
+
16
+ def _embed_documents(self) -> np.ndarray:
17
  texts = [doc['content'] for doc in self.documents]
18
  return self.model.encode(texts)
19
+
20
+ def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
21
  query_embedding = self.model.encode(query)
22
  scores = np.dot(self.doc_embeddings, query_embedding)
23
  top_indices = np.argsort(scores)[-top_k:][::-1]