ayush2917 commited on
Commit
40c75fc
·
verified ·
1 Parent(s): 49c9e97

Update src/retrieval.py

Browse files
Files changed (1) hide show
  1. src/retrieval.py +31 -7
src/retrieval.py CHANGED
@@ -5,7 +5,6 @@ from sentence_transformers import SentenceTransformer
5
  from typing import List, Dict
6
  import logging
7
 
8
- # Configure logging
9
  logger = logging.getLogger(__name__)
10
 
11
  class DocumentRetriever:
@@ -26,7 +25,7 @@ class DocumentRetriever:
26
  raise
27
  self.data_path = data_path
28
  self.documents = self._load_documents()
29
- self.doc_embeddings = self._embed_documents()
30
 
31
  def _load_documents(self) -> List[Dict]:
32
  """Load documents from the JSON file."""
@@ -42,15 +41,40 @@ class DocumentRetriever:
42
  logger.warning(f"Invalid JSON in {self.data_path}, using empty documents")
43
  return []
44
 
45
- def _embed_documents(self) -> np.ndarray:
46
- """Embed document contents using the SentenceTransformer model."""
 
47
  if not self.documents:
48
  logger.info("No documents to embed, returning empty embeddings")
49
  return np.array([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  texts = [doc['content'] for doc in self.documents]
51
- logger.info(f"Embedding {len(texts)} documents...")
52
- embeddings = self.model.encode(texts)
53
- logger.info("Document embeddings generated successfully")
 
 
 
 
 
 
 
 
 
 
54
  return embeddings
55
 
56
  def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
 
5
  from typing import List, Dict
6
  import logging
7
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
  class DocumentRetriever:
 
25
  raise
26
  self.data_path = data_path
27
  self.documents = self._load_documents()
28
+ self.doc_embeddings = self._load_or_compute_embeddings()
29
 
30
  def _load_documents(self) -> List[Dict]:
31
  """Load documents from the JSON file."""
 
41
  logger.warning(f"Invalid JSON in {self.data_path}, using empty documents")
42
  return []
43
 
44
+ def _load_or_compute_embeddings(self) -> np.ndarray:
45
+ """Load cached embeddings or compute new ones."""
46
+ embedding_cache_path = 'data/doc_embeddings.npy'
47
  if not self.documents:
48
  logger.info("No documents to embed, returning empty embeddings")
49
  return np.array([])
50
+
51
+ # Check for cached embeddings
52
+ if os.path.exists(embedding_cache_path):
53
+ try:
54
+ embeddings = np.load(embedding_cache_path)
55
+ if embeddings.shape[0] == len(self.documents):
56
+ logger.info(f"Loaded {embeddings.shape[0]} cached embeddings from {embedding_cache_path}")
57
+ return embeddings
58
+ else:
59
+ logger.warning(f"Cached embeddings shape mismatch, recomputing...")
60
+ except Exception as e:
61
+ logger.warning(f"Failed to load cached embeddings: {str(e)}, recomputing...")
62
+
63
+ # Compute new embeddings
64
  texts = [doc['content'] for doc in self.documents]
65
+ logger.info(f"Computing embeddings for {len(texts)} documents...")
66
+ start_time = time.time()
67
+ embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
68
+ logger.info(f"Embedding {len(texts)} documents took {time.time() - start_time:.2f} seconds")
69
+
70
+ # Cache embeddings
71
+ try:
72
+ os.makedirs('data', exist_ok=True)
73
+ np.save(embedding_cache_path, embeddings)
74
+ logger.info(f"Saved embeddings to {embedding_cache_path}")
75
+ except Exception as e:
76
+ logger.warning(f"Failed to save embeddings: {str(e)}")
77
+
78
  return embeddings
79
 
80
  def retrieve(self, query: str, top_k: int = 3) -> List[Dict]: