# retriever/index_dense.py
import os

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import hashlib
import pickle
import threading
from pathlib import Path

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from classifier.utils import DEVICE

try:
    import faiss  # type: ignore
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False


def _chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def _compute_cache_key(docs, model_name):
    """Compute a hash key for caching based on document IDs and model name."""
    doc_ids = "".join(d.id for d in docs)
    content = f"{model_name}:{doc_ids}"
    return hashlib.md5(content.encode()).hexdigest()


class DenseIndex:
    def __init__(
        self,
        docs,
        model_name="sentence-transformers/embeddinggemma-300m-medical",
        batch_size=64,
        embedding_model=None,
        cache_dir=".cache/embeddings",
    ):
        self.docs = docs
        self.batch_size = batch_size
        self.cache_dir = cache_dir

        # Thread safety: the background ingest thread and search() both touch
        # the index and the batch list.
        self.lock = threading.Lock()
        self.ready_count = 0
        self.emb_batches = []  # list of numpy arrays, used by the NumPy fallback

        torch.set_num_threads(1)

        if embedding_model:
            self.model = embedding_model
            self.device = self.model.device
            # Recover the underlying model name for the cache key, if available.
            # `model_card_data` is an object, not a dict, so use getattr rather
            # than .get().
            actual_model_name = getattr(
                getattr(self.model, "model_card_data", None), "base_model", None
            ) or model_name
            if hasattr(self.model, "_model_card_vars") and "model_id" in self.model._model_card_vars:
                actual_model_name = self.model._model_card_vars["model_id"]
        else:
            self.model = SentenceTransformer(model_name, device=DEVICE)
            self.device = DEVICE
            actual_model_name = model_name

        self.cache_key = _compute_cache_key(docs, actual_model_name)
        self.cache_path = Path(cache_dir) / f"{self.cache_key}.pkl"

        # The FAISS index needs the embedding dimension, which we only learn
        # when the first batch arrives (or when the full cache loads), so
        # defer creation until then.
        self.index = None

        # Start background ingestion.
        self.ingest_thread = threading.Thread(target=self._ingest_embeddings, daemon=True)
        self.ingest_thread.start()

    def _generate_embeddings(self):
        """Yield batches of embeddings, from cache where possible."""
        texts = [d.text for d in self.docs]

        # 1. Try the full cache first.
        if self.cache_path.exists():
            print(f"Loading embeddings from cache: {self.cache_path}")
            try:
                with open(self.cache_path, "rb") as f:
                    full_emb = pickle.load(f)
                print(f"✓ Loaded {len(full_emb)} cached embeddings")
                # Yield as a single large batch.
                yield full_emb
                return
            except Exception as e:
                print(f"Cache load failed: {e}, recomputing...")

        # 2. Partial cache: a previous run saved a list of batch arrays; yield
        #    those and resume encoding where it left off.
        partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
        start_index = 0
        existing_embs = []
        if partial_cache_path.exists():
            try:
                with open(partial_cache_path, "rb") as f:
                    existing_embs = pickle.load(f)
                for batch in existing_embs:
                    yield batch
                start_index = sum(len(e) for e in existing_embs)
            except Exception:
                existing_embs = []
                start_index = 0

        # 3. Compute the remaining embeddings. New batches are appended to
        #    `existing_embs` so the partial cache always holds everything
        #    computed so far.
        texts_to_process = texts[start_index:]
        if not texts_to_process:
            return
        with torch.inference_mode():
            for part in _chunks(texts_to_process, self.batch_size):
                part_emb = self.model.encode(
                    part,
                    batch_size=self.batch_size,
                    normalize_embeddings=True,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device=self.device,
                )
                batch_emb = part_emb.astype(np.float32)
                yield batch_emb
                existing_embs.append(batch_emb)
                # Save the partial cache after every batch. Note this rewrites
                # the whole list each time, so total I/O grows quadratically.
                with open(partial_cache_path, "wb") as f:
                    pickle.dump(existing_embs, f)

    def _ingest_embeddings(self):
        """Background thread: feed embedding batches into the index."""
        all_embs = []
        for batch_emb in self._generate_embeddings():
            with self.lock:
                if _HAS_FAISS:
                    if self.index is None:
                        d = batch_emb.shape[1]
                        self.index = faiss.IndexFlatIP(d)
                    self.index.add(batch_emb)
                # Also keep the raw batches for the NumPy fallback.
                self.emb_batches.append(batch_emb)
                self.ready_count += len(batch_emb)
            all_embs.append(batch_emb)

        # Finalize: persist the full embedding matrix.
        full_emb = np.vstack(all_embs).astype(np.float32)
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as f:
            pickle.dump(full_emb, f)
        print(f"✓ Saved embeddings to cache: {self.cache_path}")

        # The partial cache is now redundant.
        partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
        if partial_cache_path.exists():
            partial_cache_path.unlink()

    def search(self, query: str, k: int = 50):
        qv = self.model.encode(
            [query],
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
            device=self.device,
        ).astype(np.float32)[0]

        with self.lock:
            current_count = self.ready_count
            if current_count == 0:
                print("Warning: Index not yet initialized, returning empty results.")
                return []

            # Search whatever has been indexed so far, even while ingestion is
            # still in progress.
            if _HAS_FAISS and self.index is not None:
                # The FAISS index is updated incrementally by the ingest thread.
                D, I = self.index.search(qv.reshape(1, -1), min(k, current_count))
                return [
                    (self.docs[int(i)], float(D[0][j]))
                    for j, i in enumerate(I[0])
                    if i != -1
                ]

            # NumPy fallback: stack the batches and rank by inner product
            # (embeddings are normalized, so this is cosine similarity).
            curr_emb = np.vstack(self.emb_batches)
            sims = curr_emb @ qv
            effective_k = min(k, len(sims))
            if effective_k >= len(sims):
                order = np.argsort(-sims)
            else:
                # argpartition finds the top-k unsorted, then we sort just those.
                idx = np.argpartition(-sims, kth=effective_k - 1)[:effective_k]
                order = idx[np.argsort(-sims[idx])]
            return [(self.docs[int(i)], float(sims[int(i)])) for i in order]

    def get_progress(self):
        """Return (current_count, total_count) of indexed documents."""
        with self.lock:
            return self.ready_count, len(self.docs)
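
if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # DenseIndex only requires objects exposing `.id` and `.text`, so a tiny
    # hypothetical `Doc` dataclass stands in for the repo's real document type.
    import time
    from dataclasses import dataclass

    @dataclass
    class Doc:
        id: str
        text: str

    docs = [
        Doc("d1", "Aspirin is commonly used to reduce fever and relieve mild pain."),
        Doc("d2", "Metformin is a first-line treatment for type 2 diabetes."),
        Doc("d3", "Hypertension is often managed with ACE inhibitors."),
    ]

    index = DenseIndex(docs)

    # Indexing runs on a daemon thread; poll get_progress() until every
    # document has been embedded if complete results are needed.
    while True:
        done, total = index.get_progress()
        if done >= total:
            break
        time.sleep(0.1)

    for doc, score in index.search("treatment for high blood pressure", k=2):
        print(f"{score:.3f}  {doc.id}  {doc.text}")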