Spaces:
Running
Running
| from typing import List, Tuple, Dict, Literal | |
| from pathlib import Path | |
| import hashlib | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from ..base import BaseRetriever | |
# Process-wide cache of loaded models, keyed by model name, so that several
# retrievers built with the same model share a single set of weights.
_model_cache: Dict[str, SentenceTransformer] = {}


class DenseRetriever(BaseRetriever):
    """Dense retrieval using sentence transformers.

    Corpus embeddings are computed once, written to an ``.npz`` file under
    ``data/embeddings/`` and reloaded on later constructions with the same
    model, index configuration and row count.

    NOTE(review): ``self.db_path``, ``self.index_config`` and
    ``self.indexing_func`` are not defined in this class; presumably they are
    set by ``BaseRetriever.__init__`` -- confirm against the base class.
    """

    def __init__(self, model_name: str, db_path: str, index_config: Literal["naive", "v1"]):
        """Load (or reuse) the model, then load or build the corpus index.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            db_path: Path to the CSV database (forwarded to the base class).
            index_config: Name of the indexing strategy to use.
        """
        super().__init__(db_path, index_config)
        self.model_name = model_name
        if model_name not in _model_cache:
            model = SentenceTransformer(model_name)
            model.half()  # fp16: halves model memory footprint
            _model_cache[model_name] = model
        self.model = _model_cache[model_name]
        self.corpus_embeddings = None
        self.embeddings_path = self._default_embeddings_path()
        if self._embeddings_exist():
            self.load_index()
        else:
            self.build_index()

    def _default_embeddings_path(self) -> str:
        """Return the cache file path for this model/config/database combination."""
        model_safe = self.model_name.replace("/", "_")
        # NOTE: the key only captures the row count of the CSV, so edits that
        # keep the number of rows unchanged will not invalidate a cached index.
        payload = {
            "model": self.model_name,
            "columns": self.index_config,
            "rows": len(pd.read_csv(self.db_path)),
        }
        # stable SHA-256 hash based on payload
        hash_str = hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()[:16]
        return f"data/embeddings/{model_safe}_{hash_str}.npz"

    def _embeddings_exist(self) -> bool:
        """Return True if the cache file for this retriever is already on disk."""
        return Path(self.embeddings_path).exists()

    def _store_index(self) -> None:
        """Write embeddings, ids and identifying metadata to ``self.embeddings_path``."""
        path = Path(self.embeddings_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            path,
            embeddings=self.corpus_embeddings,
            agent_ids=np.array(self.agent_ids),
            model_name=self.model_name,
            index_config=self.index_config,
        )

    def load_index(self) -> None:
        """
        Use precomputed embeddings.

        Raises:
            ValueError: If the stored index configuration differs from
                ``self.index_config``.
        """
        # SECURITY: allow_pickle=True will unpickle object arrays found in the
        # file; only load embedding files from a trusted location. It is kept
        # because ids saved as an object-dtype array could not be read without it.
        # The context manager closes the underlying archive handle.
        with np.load(self.embeddings_path, allow_pickle=True) as data:
            embeddings = data['embeddings']
            agent_ids = data['agent_ids'].tolist()
            stored_model = str(data['model_name'])
            stored_index_config = data['index_config'].tolist()

        # verify metadata before touching instance state
        if stored_model != self.model_name:
            print(f"WARNING: Loaded embeddings from {stored_model}, but using {self.model_name}")
        if stored_index_config != self.index_config:
            raise ValueError(f"Index configuration mismatch! Stored: {stored_index_config}, Expected: {self.index_config}")

        self.corpus_embeddings = embeddings
        self.agent_ids = agent_ids

    def build_index(self) -> None:
        """
        Build your embeddings.

        Runs the indexing function selected by ``self.index_config`` to obtain
        the ids and corpus texts, encodes the corpus and stores the result on
        disk.
        """
        self.agent_ids, self.corpus = self.indexing_func[self.index_config]()
        self.corpus_embeddings = self.model.encode(
            self.corpus,
            show_progress_bar=True,
            convert_to_numpy=True,
            # encode() does NOT normalise by default; request it explicitly so
            # the dot product in retrieve() is a cosine similarity for every model.
            normalize_embeddings=True,
        )
        self._store_index()  # avoid re-building

    def retrieve(self, query: str, top_k: int = 10):
        """
        Retrieve using dot product.

        Args:
            query: Query text to embed.
            top_k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            A list of ``(agent_id, score)`` tuples, best score first, with at
            most ``top_k`` entries.
        """
        # A slice of [-0:] would select the whole array, so guard explicitly.
        if top_k <= 0:
            return []
        query_embedding = self.model.encode(
            [query],
            convert_to_numpy=True,
            normalize_embeddings=True,  # match the normalised corpus vectors
        )[0]
        scores = np.dot(self.corpus_embeddings, query_embedding)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        return [(self.agent_ids[idx], float(scores[idx])) for idx in top_indices]