ayush2917 commited on
Commit
34fbe97
·
verified ·
1 Parent(s): 58c35ab

Update src/retrieval.py

Browse files
Files changed (1) hide show
  1. src/retrieval.py +8 -14
src/retrieval.py CHANGED
@@ -1,29 +1,23 @@
1
- import numpy as np
2
  from sentence_transformers import SentenceTransformer
 
3
  import json
4
- import os
5
- from typing import List, Dict
6
 
7
  class DocumentRetriever:
8
  def __init__(self, model_name='all-MiniLM-L6-v2'):
9
  self.model = SentenceTransformer(model_name)
10
  self.documents = self._load_documents()
11
  self.doc_embeddings = self._embed_documents()
12
-
13
- def _load_documents(self) -> List[Dict]:
14
  with open('data/rupeia_document.json', 'r') as f:
15
  return json.load(f)
16
-
17
- def _embed_documents(self) -> np.ndarray:
18
  texts = [doc['content'] for doc in self.documents]
19
  return self.model.encode(texts)
20
-
21
- def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
22
  query_embedding = self.model.encode(query)
23
  scores = np.dot(self.doc_embeddings, query_embedding)
24
  top_indices = np.argsort(scores)[-top_k:][::-1]
25
- return [self.documents[i] for i in top_indices]
26
-
27
- def retrieve_relevant_documents(query: str) -> List[Dict]:
28
- retriever = DocumentRetriever()
29
- return retriever.retrieve(query)
 
 
1
  from sentence_transformers import SentenceTransformer
2
+ import numpy as np
3
  import json
 
 
4
 
5
  class DocumentRetriever:
6
  def __init__(self, model_name='all-MiniLM-L6-v2'):
7
  self.model = SentenceTransformer(model_name)
8
  self.documents = self._load_documents()
9
  self.doc_embeddings = self._embed_documents()
10
+
11
+ def _load_documents(self):
12
  with open('data/rupeia_document.json', 'r') as f:
13
  return json.load(f)
14
+
15
+ def _embed_documents(self):
16
  texts = [doc['content'] for doc in self.documents]
17
  return self.model.encode(texts)
18
+
19
+ def retrieve(self, query: str, top_k: int = 3):
20
  query_embedding = self.model.encode(query)
21
  scores = np.dot(self.doc_embeddings, query_embedding)
22
  top_indices = np.argsort(scores)[-top_k:][::-1]
23
+ return [self.documents[i] for i in top_indices]