"""Dense retrieval index: embeddings computed on a background thread, cached to
disk (with resumable partial checkpoints), searched via FAISS or a numpy
fallback."""

import os

# Silence the Hugging Face tokenizers fork/parallelism warning; this must be
# set before any tokenizer is loaded.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import hashlib
import pickle
import threading
from pathlib import Path

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from classifier.utils import DEVICE

# FAISS is optional; without it, search falls back to brute-force numpy scoring.
try:
    import faiss
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False


def _chunks(lst, n):
    """Yield successive n-sized slices of lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def _compute_cache_key(docs, model_name):
    """Compute a hash key for caching based on documents and model.

    The key covers document ids and the model name only, so a cached matrix
    goes stale if a document's text changes under the same id.
    """
    doc_ids = "".join([d.id for d in docs])
    content = f"{model_name}:{doc_ids}"
    return hashlib.md5(content.encode()).hexdigest()  # non-cryptographic key


class DenseIndex:
    """Dense index over `docs`, built incrementally on a background thread."""

    def __init__(self, docs, model_name="sentence-transformers/embeddinggemma-300m-medical",
                 batch_size=64, embedding_model=None, cache_dir=".cache/embeddings"):
        self.docs = docs
        self.batch_size = batch_size
        self.cache_dir = cache_dir

        # Protects index, emb_batches, and ready_count against the ingest thread.
        self.lock = threading.Lock()
        self.ready_count = 0
        self.emb_batches = []

        # Avoid CPU oversubscription: encoding already runs on its own thread.
        torch.set_num_threads(1)
        if embedding_model:
            self.model = embedding_model
            self.device = self.model.device
            # Best-effort recovery of the real model name for the cache key;
            # model_card_data may be absent or not expose these fields, so use
            # defensive attribute access rather than dict-style .get().
            actual_model_name = model_name
            card_data = getattr(self.model, 'model_card_data', None)
            if card_data is not None:
                actual_model_name = getattr(card_data, 'base_model', None) or model_name
            if hasattr(self.model, '_model_card_vars') and 'model_id' in self.model._model_card_vars:
                actual_model_name = self.model._model_card_vars['model_id']
        else:
            self.model = SentenceTransformer(model_name, device=DEVICE)
            self.device = DEVICE
            actual_model_name = model_name

        self.cache_key = _compute_cache_key(docs, actual_model_name)
        self.cache_path = Path(cache_dir) / f"{self.cache_key}.pkl"
        # Create the cache directory up front so partial dumps cannot fail on a
        # missing parent directory.
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)

        # Created lazily by the ingest thread once the embedding dimension is
        # known; stays None when FAISS is unavailable.
        self.index = None

        self.ingest_thread = threading.Thread(target=self._ingest_embeddings, daemon=True)
        self.ingest_thread.start()

    def _generate_embeddings(self):
        """Yield batches of embeddings, served from the on-disk cache when possible."""
        texts = [d.text for d in self.docs]

        # Fast path: a complete cached matrix is yielded as a single batch.
        if self.cache_path.exists():
            print(f"Loading embeddings from cache: {self.cache_path}")
            try:
                with open(self.cache_path, 'rb') as f:
                    full_emb = pickle.load(f)
                print(f"✓ Loaded {len(full_emb)} cached embeddings")
                yield full_emb
                return
            except Exception as e:
                print(f"Cache load failed: {e}, recomputing...")

        # Resume from a partial checkpoint left by an interrupted run: replay
        # the batches already computed, then continue where it stopped.
        partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
        start_index = 0
        existing_embs = []

        if partial_cache_path.exists():
            try:
                with open(partial_cache_path, 'rb') as f:
                    existing_embs = pickle.load(f)

                for batch in existing_embs:
                    yield batch

                start_index = sum(len(e) for e in existing_embs)
            except Exception:
                existing_embs = []
                start_index = 0

        texts_to_process = texts[start_index:]
        if not texts_to_process:
            return

        # Encode the remaining texts batch by batch, checkpointing after every
        # batch so an interrupted run can resume from the partial cache.
        with torch.inference_mode():
            for part in _chunks(texts_to_process, self.batch_size):
                part_emb = self.model.encode(
                    part,
                    batch_size=self.batch_size,
                    normalize_embeddings=True,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device=self.device,
                )
                batch_emb = part_emb.astype(np.float32)
                yield batch_emb

                existing_embs.append(batch_emb)
                # Rewrites the full list every batch (O(n^2) I/O overall), which
                # is acceptable for modest corpora.
                with open(partial_cache_path, 'wb') as f:
                    pickle.dump(existing_embs, f)
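
    # _ingest_embeddings below is the consumer half of a producer/consumer pair:
    # _generate_embeddings produces batches (cached or freshly encoded) and this
    # background thread folds them into the searchable structures under the lock.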
    def _ingest_embeddings(self):
        """Background thread: consume embedding batches and make them searchable."""
        all_embs = []

        for batch_emb in self._generate_embeddings():
            with self.lock:
                if _HAS_FAISS:
                    if self.index is None:
                        d = batch_emb.shape[1]
                        # Inner product over normalized vectors equals cosine.
                        self.index = faiss.IndexFlatIP(d)
                    self.index.add(batch_emb)

                self.emb_batches.append(batch_emb)
                self.ready_count += len(batch_emb)

            all_embs.append(batch_emb)

        if not all_embs:
            return  # nothing to index; avoid np.vstack on an empty list

        full_emb = np.vstack(all_embs).astype(np.float32)

        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, 'wb') as f:
            pickle.dump(full_emb, f)
        print(f"✓ Saved embeddings to cache: {self.cache_path}")

        # The full cache supersedes the partial checkpoint.
        partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
        if partial_cache_path.exists():
            partial_cache_path.unlink()
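
    # search() embeds the query once, then uses FAISS when available; otherwise
    # it scores a numpy snapshot of whatever has been ingested so far.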
    def search(self, query: str, k: int = 50):
        """Return up to k (doc, score) pairs over the documents indexed so far."""
        qv = self.model.encode(
            [query],
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
            device=self.device,
        ).astype(np.float32)[0]

        with self.lock:
            current_count = self.ready_count
            if current_count == 0:
                print("Warning: Index not yet initialized, returning empty results.")
                return []

            if _HAS_FAISS and self.index is not None:
                # FAISS pads results with id -1 when fewer than k vectors match.
                D, I = self.index.search(qv.reshape(1, -1), min(k, current_count))
                return [(self.docs[int(i)], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]

            # Brute-force fallback: snapshot the embeddings while still holding
            # the lock; scoring happens outside it.
            curr_emb = np.vstack(self.emb_batches)

        # Cosine scores (embeddings are normalized). argpartition selects the
        # top-k without fully sorting when k < n.
        sims = curr_emb @ qv
        effective_k = min(k, len(sims))

        if effective_k >= len(sims):
            order = np.argsort(-sims)
        else:
            idx = np.argpartition(-sims, kth=effective_k - 1)[:effective_k]
            order = idx[np.argsort(-sims[idx])]

        return [(self.docs[int(i)], float(sims[int(i)])) for i in order]

    def get_progress(self):
        """Returns (current_count, total_count) of indexed documents."""
        with self.lock:
            return self.ready_count, len(self.docs)
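

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the module's API): DenseIndex only
# needs objects exposing `.id` and `.text`, so the `Doc` dataclass below is a
# stand-in; running this also assumes the default embedding model can be loaded
# on this machine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class Doc:
        id: str
        text: str

    docs = [
        Doc("d1", "Metformin is a first-line treatment for type 2 diabetes."),
        Doc("d2", "ACE inhibitors are used to lower blood pressure."),
    ]
    index = DenseIndex(docs)
    index.ingest_thread.join()  # block until every document is embedded
    for doc, score in index.search("diabetes medication", k=1):
        print(doc.id, round(score, 3))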