import os
import json
import logging
import time
from typing import Dict, List

import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class DocumentRetriever:
    """Semantic document retriever backed by a SentenceTransformer model.

    Loads a JSON corpus, embeds each document once (with on-disk caching of
    the embedding matrix), and answers queries by cosine similarity between
    the query embedding and the document embeddings.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2',
                 data_path='data/rupeia_document.json',
                 cache_folder=None,
                 embedding_cache_path='data/doc_embeddings.npy'):
        """
        Initialize the DocumentRetriever with a SentenceTransformer model.

        Args:
            model_name (str): Name of the SentenceTransformer model
                (default: 'all-MiniLM-L6-v2').
            data_path (str): Path to the document JSON file
                (default: 'data/rupeia_document.json').
            cache_folder (str, optional): Directory to cache model files
                (default: None).
            embedding_cache_path (str): Where to persist the computed
                document-embedding matrix (default: 'data/doc_embeddings.npy').

        Raises:
            Exception: Propagates any failure from SentenceTransformer model
                loading (e.g. download or deserialization errors).
        """
        logger.info(f"Initializing DocumentRetriever with model: {model_name}, cache_folder: {cache_folder}")
        try:
            self.model = SentenceTransformer(model_name, cache_folder=cache_folder)
        except Exception as e:
            # Model loading is fatal: log for diagnostics, then re-raise.
            logger.error(f"Failed to load SentenceTransformer model: {str(e)}")
            raise
        self.data_path = data_path
        self.embedding_cache_path = embedding_cache_path
        self.documents = self._load_documents()
        self.doc_embeddings = self._load_or_compute_embeddings()

    def _load_documents(self) -> List[Dict]:
        """Load the document list from the JSON file.

        Returns:
            List[Dict]: The parsed documents. An empty list is returned (with
            a warning) when the file is missing, is not valid JSON, or its
            root element is not a list — a non-list root would pass ``len()``
            but break the integer indexing done in ``retrieve``.
        """
        try:
            # Explicit encoding: JSON is UTF-8 by spec; don't depend on the
            # platform default codec.
            with open(self.data_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Data file not found at {self.data_path}, using empty documents")
            return []
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON in {self.data_path}, using empty documents")
            return []
        if not isinstance(documents, list):
            logger.warning(f"Expected a JSON list in {self.data_path}, got {type(documents).__name__}; using empty documents")
            return []
        logger.info(f"Loaded {len(documents)} documents from {self.data_path}")
        return documents

    def _load_or_compute_embeddings(self) -> np.ndarray:
        """Load cached embeddings or compute (and cache) new ones.

        The cache is considered valid only when its row count matches the
        current number of documents; otherwise embeddings are recomputed.

        Returns:
            np.ndarray: Matrix of shape (n_documents, embedding_dim), or an
            empty array when there are no documents.
        """
        if not self.documents:
            logger.info("No documents to embed, returning empty embeddings")
            return np.array([])

        # Check for cached embeddings.
        if os.path.exists(self.embedding_cache_path):
            try:
                embeddings = np.load(self.embedding_cache_path)
                if embeddings.shape[0] == len(self.documents):
                    logger.info(f"Loaded {embeddings.shape[0]} cached embeddings from {self.embedding_cache_path}")
                    return embeddings
                logger.warning("Cached embeddings shape mismatch, recomputing...")
            except Exception as e:
                logger.warning(f"Failed to load cached embeddings: {str(e)}, recomputing...")

        # Compute new embeddings. Tolerate records without a 'content' key
        # (embed an empty string) rather than crashing on one bad record.
        texts = []
        for doc in self.documents:
            content = doc.get('content', '')
            if not content:
                logger.warning("Document missing 'content' field; embedding empty text")
            texts.append(content)
        logger.info(f"Computing embeddings for {len(texts)} documents...")
        start_time = time.time()
        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        logger.info(f"Embedding {len(texts)} documents took {time.time() - start_time:.2f} seconds")

        # Cache embeddings; failure to persist is non-fatal (best effort).
        try:
            cache_dir = os.path.dirname(self.embedding_cache_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            np.save(self.embedding_cache_path, embeddings)
            logger.info(f"Saved embeddings to {self.embedding_cache_path}")
        except Exception as e:
            logger.warning(f"Failed to save embeddings: {str(e)}")
        return embeddings

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the top-k most relevant documents for a given query.

        Args:
            query (str): Free-text query to embed and match.
            top_k (int): Maximum number of documents to return. Non-positive
                values return an empty list (previously ``top_k=0`` returned
                ALL documents due to the ``[-0:]`` slice).

        Returns:
            List[Dict]: Up to ``top_k`` documents ordered by descending
            cosine similarity to the query.
        """
        if not self.documents:
            logger.warning("No documents available for retrieval")
            return []
        if top_k <= 0:
            # Guard the [-top_k:] slice: [-0:] would select every row.
            return []

        logger.info(f"Retrieving top {top_k} documents for query: {query}")
        query_embedding = self.model.encode(query)

        # Cosine similarity rather than a raw dot product: encode() does not
        # normalize embeddings by default, so raw dot products would bias
        # ranking toward documents with large-norm embeddings. Normalizing at
        # scoring time also keeps previously cached (unnormalized)
        # embedding files valid.
        doc_norms = np.linalg.norm(self.doc_embeddings, axis=1)
        denom = doc_norms * np.linalg.norm(query_embedding)
        denom = np.where(denom == 0, 1e-12, denom)  # avoid division by zero
        scores = self.doc_embeddings @ query_embedding / denom

        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        logger.info(f"Retrieved {len(results)} documents")
        return results