# support-system / src/retrieval.py
# Last updated by ayush2917 ("Update src/retrieval.py", commit 2c3faf7, verified).
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import logging
import time
logger = logging.getLogger(__name__)
class DocumentRetriever:
    """Retrieve relevant documents by semantic similarity.

    Documents are loaded from a JSON file, embedded once with a
    SentenceTransformer model (with best-effort on-disk caching of the
    embeddings), and ranked against queries by cosine similarity.
    """

    # Shared on-disk embedding cache. NOTE(review): this path does not vary
    # with data_path, so two retrievers over different corpora would clobber
    # each other's cache — value kept as-is for backward compatibility with
    # existing caches; confirm before changing.
    _EMBEDDING_CACHE_PATH = 'data/doc_embeddings.npy'

    def __init__(self, model_name='all-MiniLM-L6-v2', data_path='data/rupeia_document.json', cache_folder=None):
        """
        Initialize the DocumentRetriever with a SentenceTransformer model.

        Args:
            model_name (str): Name of the SentenceTransformer model (default: 'all-MiniLM-L6-v2').
            data_path (str): Path to the document JSON file (default: 'data/rupeia_document.json').
            cache_folder (str, optional): Directory to cache model files (default: None).

        Raises:
            Exception: whatever SentenceTransformer raises on load failure
                is logged and re-raised unchanged.
        """
        logger.info(f"Initializing DocumentRetriever with model: {model_name}, cache_folder: {cache_folder}")
        try:
            self.model = SentenceTransformer(model_name, cache_folder=cache_folder)
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model: {str(e)}")
            raise
        self.data_path = data_path
        self.documents = self._load_documents()
        self.doc_embeddings = self._load_or_compute_embeddings()

    def _load_documents(self) -> List[Dict]:
        """Load documents from the JSON file; return [] on any load failure."""
        try:
            # Explicit encoding avoids platform-dependent decoding of the corpus.
            with open(self.data_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)
            logger.info(f"Loaded {len(documents)} documents from {self.data_path}")
            return documents
        except FileNotFoundError:
            logger.warning(f"Data file not found at {self.data_path}, using empty documents")
            return []
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON in {self.data_path}, using empty documents")
            return []

    def _load_or_compute_embeddings(self) -> np.ndarray:
        """Return document embeddings, loading from the .npy cache when valid.

        The cache is considered valid only when its row count matches the
        number of loaded documents; otherwise embeddings are recomputed and
        the cache is rewritten (best-effort — a failed save is logged, not
        raised).
        """
        if not self.documents:
            logger.info("No documents to embed, returning empty embeddings")
            return np.array([])
        cache_path = self._EMBEDDING_CACHE_PATH
        if os.path.exists(cache_path):
            try:
                embeddings = np.load(cache_path)
                if embeddings.shape[0] == len(self.documents):
                    logger.info(f"Loaded {embeddings.shape[0]} cached embeddings from {cache_path}")
                    return embeddings
                # Row-count mismatch: corpus changed since the cache was written.
                logger.warning(f"Cached embeddings shape mismatch, recomputing...")
            except Exception as e:
                logger.warning(f"Failed to load cached embeddings: {str(e)}, recomputing...")
        # Compute new embeddings. A record missing 'content' becomes an empty
        # string rather than aborting the whole corpus with a KeyError.
        texts = [doc.get('content', '') for doc in self.documents]
        logger.info(f"Computing embeddings for {len(texts)} documents...")
        start_time = time.time()
        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        logger.info(f"Embedding {len(texts)} documents took {time.time() - start_time:.2f} seconds")
        # Cache embeddings (best-effort).
        try:
            os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
            np.save(cache_path, embeddings)
            logger.info(f"Saved embeddings to {cache_path}")
        except Exception as e:
            logger.warning(f"Failed to save embeddings: {str(e)}")
        return embeddings

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return the top_k documents most similar to the query.

        Ranking uses cosine similarity (a raw dot product on unnormalized
        embeddings conflates vector magnitude with relevance). top_k <= 0
        yields an empty list; top_k larger than the corpus returns every
        document, most relevant first.
        """
        if not self.documents:
            logger.warning("No documents available for retrieval")
            return []
        if top_k <= 0:
            # Guard: the previous [-top_k:] slice returned ALL documents for
            # top_k == 0 because [-0:] spans the whole array.
            return []
        logger.info(f"Retrieving top {top_k} documents for query: {query}")
        query_embedding = np.asarray(self.model.encode(query))
        # Cosine similarity: normalize both sides; epsilon guards zero vectors.
        doc_norms = np.linalg.norm(self.doc_embeddings, axis=1)
        query_norm = np.linalg.norm(query_embedding)
        scores = self.doc_embeddings @ query_embedding / (doc_norms * query_norm + 1e-10)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        logger.info(f"Retrieved {len(results)} documents")
        return results