# NOTE(review): removed non-Python extraction residue that preceded the
# imports (file-viewer chrome: "Runtime error" banners, blame commit hashes,
# and a row of line numbers). It was never part of the module and broke the
# file's syntax.
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import logging
import time
# Module-level logger namespaced to this module, so log records from the
# retriever can be filtered/configured independently by the host application.
logger = logging.getLogger(__name__)
class DocumentRetriever:
    """Semantic document retriever backed by a SentenceTransformer model.

    Loads a JSON list of documents from ``data_path``, embeds each document's
    ``content`` field (caching the embeddings on disk so restarts skip the
    slow encode step), and answers free-text queries with cosine-similarity
    top-k search.
    """

    # On-disk embedding cache, shared across instances. Reused only when its
    # row count matches the current document count (see _load_or_compute_embeddings).
    EMBEDDING_CACHE_PATH = 'data/doc_embeddings.npy'

    def __init__(self, model_name='all-MiniLM-L6-v2', data_path='data/rupeia_document.json', cache_folder=None):
        """
        Initialize the DocumentRetriever with a SentenceTransformer model.

        Args:
            model_name (str): Name of the SentenceTransformer model (default: 'all-MiniLM-L6-v2').
            data_path (str): Path to the document JSON file (default: 'data/rupeia_document.json').
            cache_folder (str, optional): Directory to cache model files (default: None).

        Raises:
            Exception: re-raised unchanged if the SentenceTransformer model
                cannot be loaded (a retriever without a model is unusable).
        """
        logger.info("Initializing DocumentRetriever with model: %s, cache_folder: %s", model_name, cache_folder)
        try:
            self.model = SentenceTransformer(model_name, cache_folder=cache_folder)
        except Exception as e:
            logger.error("Failed to load SentenceTransformer model: %s", e)
            raise
        self.data_path = data_path
        self.documents = self._load_documents()
        self.doc_embeddings = self._load_or_compute_embeddings()

    def _load_documents(self) -> List[Dict]:
        """Load the document list from the JSON file.

        Returns:
            List[Dict]: Parsed documents; an empty list when the file is
            missing or contains invalid JSON, so the retriever stays usable
            (retrieve() then returns no results instead of crashing).
        """
        try:
            # Explicit UTF-8 so decoding does not depend on the platform's
            # default locale encoding.
            with open(self.data_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)
            logger.info("Loaded %d documents from %s", len(documents), self.data_path)
            return documents
        except FileNotFoundError:
            logger.warning("Data file not found at %s, using empty documents", self.data_path)
            return []
        except json.JSONDecodeError:
            logger.warning("Invalid JSON in %s, using empty documents", self.data_path)
            return []

    def _load_or_compute_embeddings(self) -> np.ndarray:
        """Load cached document embeddings, or compute and cache them.

        The disk cache is trusted only when its row count matches the number
        of loaded documents; any mismatch or load failure triggers a fresh
        encode pass.

        Returns:
            np.ndarray: One embedding row per document (empty array when
            there are no documents).
        """
        if not self.documents:
            logger.info("No documents to embed, returning empty embeddings")
            return np.array([])
        # Prefer the on-disk cache over re-encoding the whole corpus.
        if os.path.exists(self.EMBEDDING_CACHE_PATH):
            try:
                embeddings = np.load(self.EMBEDDING_CACHE_PATH)
                if embeddings.shape[0] == len(self.documents):
                    logger.info("Loaded %d cached embeddings from %s", embeddings.shape[0], self.EMBEDDING_CACHE_PATH)
                    return embeddings
                logger.warning("Cached embeddings shape mismatch, recomputing...")
            except Exception as e:
                logger.warning("Failed to load cached embeddings: %s, recomputing...", e)
        # A document missing its 'content' key degrades to an empty string
        # instead of raising KeyError and aborting initialization.
        texts = [doc.get('content', '') for doc in self.documents]
        logger.info("Computing embeddings for %d documents...", len(texts))
        start_time = time.time()
        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        logger.info("Embedding %d documents took %.2f seconds", len(texts), time.time() - start_time)
        # Best-effort cache write: a failure here only costs a recompute on
        # the next startup, so log and continue.
        try:
            os.makedirs('data', exist_ok=True)
            np.save(self.EMBEDDING_CACHE_PATH, embeddings)
            logger.info("Saved embeddings to %s", self.EMBEDDING_CACHE_PATH)
        except Exception as e:
            logger.warning("Failed to save embeddings: %s", e)
        return embeddings

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` documents ranked by cosine similarity to the query.

        Args:
            query (str): Free-text search query.
            top_k (int): Maximum number of documents to return.

        Returns:
            List[Dict]: Documents ordered from most to least similar; empty
            when no documents are loaded.
        """
        if not self.documents:
            logger.warning("No documents available for retrieval")
            return []
        logger.info("Retrieving top %d documents for query: %s", top_k, query)
        query_embedding = np.asarray(self.model.encode(query))
        # BUGFIX: rank by cosine similarity rather than a raw dot product.
        # Sentence-transformer embeddings are not guaranteed unit-length, so
        # an unnormalized dot product biases ranking toward documents whose
        # embeddings merely have large norms. Normalizing here (instead of at
        # encode time) also keeps previously cached embeddings valid.
        doc_norms = np.linalg.norm(self.doc_embeddings, axis=1)
        query_norm = np.linalg.norm(query_embedding)
        denom = doc_norms * max(query_norm, 1e-12)
        denom[denom == 0] = 1e-12  # guard against degenerate all-zero vectors
        scores = self.doc_embeddings @ query_embedding / denom
        # argsort is ascending: take the last top_k scores, then reverse so
        # results come back best-first.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        logger.info("Retrieved %d documents", len(results))
        return results