# NOTE(review): removed paste artifacts ("Spaces:", "Runtime error" x2)
# that were not part of the module source.
| import os | |
| import json | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict | |
| import logging | |
| import time | |
logger = logging.getLogger(__name__)


class DocumentRetriever:
    """Retrieve documents from a JSON corpus by embedding similarity.

    Documents are loaded from ``data_path`` (expected: a JSON list of dicts,
    each carrying a ``'content'`` field), embedded once with a
    SentenceTransformer model, and cached on disk so later runs skip
    re-encoding. ``retrieve`` ranks documents by cosine similarity to the
    query and returns the top-k document dicts.
    """

    # On-disk cache for corpus embeddings; reused when the stored row count
    # matches the current document count.
    _EMBEDDING_CACHE_PATH = 'data/doc_embeddings.npy'

    def __init__(self, model_name='all-MiniLM-L6-v2',
                 data_path='data/rupeia_document.json', cache_folder=None):
        """
        Initialize the DocumentRetriever with a SentenceTransformer model.

        Args:
            model_name (str): Name of the SentenceTransformer model
                (default: 'all-MiniLM-L6-v2').
            data_path (str): Path to the document JSON file
                (default: 'data/rupeia_document.json').
            cache_folder (str, optional): Directory to cache model files
                (default: None).

        Raises:
            Exception: re-raised unchanged if the model fails to load.
        """
        logger.info("Initializing DocumentRetriever with model: %s, cache_folder: %s",
                    model_name, cache_folder)
        try:
            self.model = SentenceTransformer(model_name, cache_folder=cache_folder)
        except Exception as e:
            logger.error("Failed to load SentenceTransformer model: %s", e)
            raise
        self.data_path = data_path
        self.documents = self._load_documents()
        self.doc_embeddings = self._load_or_compute_embeddings()

    def _load_documents(self) -> List[Dict]:
        """Load the document list from ``self.data_path``.

        Returns an empty list (with a warning) when the file is missing or
        contains invalid JSON, so the retriever degrades gracefully instead
        of crashing at startup.
        """
        try:
            # Explicit encoding avoids platform-dependent default codecs.
            with open(self.data_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)
            logger.info("Loaded %d documents from %s", len(documents), self.data_path)
            return documents
        except FileNotFoundError:
            logger.warning("Data file not found at %s, using empty documents", self.data_path)
            return []
        except json.JSONDecodeError:
            logger.warning("Invalid JSON in %s, using empty documents", self.data_path)
            return []

    def _load_or_compute_embeddings(self) -> np.ndarray:
        """Return embeddings for ``self.documents``.

        Loads the on-disk cache when its row count matches the corpus;
        otherwise encodes all documents and rewrites the cache (best-effort).

        NOTE(review): the cache is validated only by document *count*, not
        content — edits that keep the count unchanged will serve stale
        vectors. Consider hashing the corpus if that matters.
        """
        embedding_cache_path = self._EMBEDDING_CACHE_PATH
        if not self.documents:
            logger.info("No documents to embed, returning empty embeddings")
            return np.array([])
        # Prefer the cached matrix when it plausibly matches this corpus.
        if os.path.exists(embedding_cache_path):
            try:
                embeddings = np.load(embedding_cache_path)
                if embeddings.shape[0] == len(self.documents):
                    logger.info("Loaded %d cached embeddings from %s",
                                embeddings.shape[0], embedding_cache_path)
                    return embeddings
                logger.warning("Cached embeddings shape mismatch, recomputing...")
            except Exception as e:
                logger.warning("Failed to load cached embeddings: %s, recomputing...", e)
        texts = [doc['content'] for doc in self.documents]
        logger.info("Computing embeddings for %d documents...", len(texts))
        start_time = time.time()
        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        logger.info("Embedding %d documents took %.2f seconds",
                    len(texts), time.time() - start_time)
        # Best-effort cache write: failing to persist must not break retrieval.
        try:
            os.makedirs(os.path.dirname(embedding_cache_path) or '.', exist_ok=True)
            np.save(embedding_cache_path, embeddings)
            logger.info("Saved embeddings to %s", embedding_cache_path)
        except Exception as e:
            logger.warning("Failed to save embeddings: %s", e)
        return embeddings

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the top-k most relevant documents for a given query.

        Args:
            query (str): Free-text query to embed and match.
            top_k (int): Number of documents to return; non-positive values
                yield an empty list.

        Returns:
            List[Dict]: Up to ``top_k`` document dicts, most similar first.
        """
        if not self.documents:
            logger.warning("No documents available for retrieval")
            return []
        # BUG FIX: np.argsort(scores)[-0:] slices the *entire* array, so a
        # non-positive top_k previously returned every document.
        if top_k <= 0:
            return []
        logger.info("Retrieving top %d documents for query: %s", top_k, query)
        query_embedding = self.model.encode(query)
        # Cosine similarity: SentenceTransformer embeddings are not
        # length-normalized by default, so a raw dot product conflates
        # relevance with vector magnitude. Normalize both sides; the epsilon
        # guards against all-zero vectors.
        doc_norms = np.linalg.norm(self.doc_embeddings, axis=1)
        query_norm = np.linalg.norm(query_embedding)
        scores = (self.doc_embeddings @ query_embedding) / (doc_norms * query_norm + 1e-12)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        logger.info("Retrieved %d documents", len(results))
        return results