from typing import List, Tuple, Dict, Literal from pathlib import Path import hashlib import json import numpy as np import pandas as pd from sentence_transformers import SentenceTransformer from ..base import BaseRetriever _model_cache: Dict[str, SentenceTransformer] = {} class DenseRetriever(BaseRetriever): """Dense retrieval using sentence transformers.""" def __init__(self, model_name: str, db_path: str, index_config: Literal["naive", "v1"]): super().__init__(db_path, index_config) self.model_name = model_name if model_name not in _model_cache: model = SentenceTransformer(model_name) model.half() # fp16: halves model memory footprint _model_cache[model_name] = model self.model = _model_cache[model_name] self.corpus_embeddings = None self.embeddings_path = self._default_embeddings_path() self.load_index() if self._embeddings_exist() else self.build_index() def _default_embeddings_path(self) -> str: model_safe = self.model_name.replace("/", "_") payload = { "model": self.model_name, "columns": self.index_config, "rows": len(pd.read_csv(self.db_path)) } # stable SHA-256 hash based on payload hash_str = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()[:16] return f"data/embeddings/{model_safe}_{hash_str}.npz" def _embeddings_exist(self) -> bool: return Path(self.embeddings_path).exists() def _store_index(self): path = Path(self.embeddings_path) path.parent.mkdir(parents=True, exist_ok=True) np.savez_compressed( path, embeddings=self.corpus_embeddings, agent_ids=np.array(self.agent_ids), model_name=self.model_name, index_config=self.index_config ) def load_index(self): """ Use precomputed embeddings. """ data = np.load(self.embeddings_path, allow_pickle=True) self.corpus_embeddings = data['embeddings'] self.agent_ids = data['agent_ids'].tolist() # verify metadata stored_model = str(data['model_name']) stored_index_config = data['index_config'].tolist() if stored_model != self.model_name: print(f"WARNING: Loaded embeddings from {stored_model}, but using {self.model_name}") if stored_index_config != self.index_config: raise ValueError(f"Index configuration mismatch! Stored: {stored_index_config}, Expected: {self.index_config}") def build_index(self): """ Build your embeddings. """ self.agent_ids, self.corpus = self.indexing_func[self.index_config]() self.corpus_embeddings = self.model.encode( self.corpus, show_progress_bar=True, convert_to_numpy=True, ) self._store_index() # avoid re-building def retrieve(self, query: str, top_k: int = 10): """ Retrieve using dot product. """ # NOTE: .encode() internally normalises our vectors query_embedding = self.model.encode([query], convert_to_numpy=True)[0] scores = np.dot(self.corpus_embeddings, query_embedding) top_indices = np.argsort(scores)[-top_k:][::-1] return [(self.agent_ids[idx], float(scores[idx])) for idx in top_indices]