"""Embedding wrapper around sentence-transformers with disk cache.""" from __future__ import annotations import hashlib from pathlib import Path from typing import Any import numpy as np class Embedder: """Embeds text using sentence-transformers with optional disk cache. Accepts any object with an encode() method so tests can inject a mock without downloading the 80MB model. """ def __init__( self, model: Any = None, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = ".cache/embeddings", ) -> None: self._model_name = model_name if model is not None: self._model = model else: from sentence_transformers import SentenceTransformer self._model = SentenceTransformer(model_name) self._cache_dir = Path(cache_dir) self._cache_dir.mkdir(parents=True, exist_ok=True) def _cache_key(self, text: str) -> str: """Cache key scoped to model name + text content.""" raw = f"{self._model_name}:{text}" return hashlib.sha256(raw.encode()).hexdigest() def embed(self, text: str) -> np.ndarray: """Embed a single text string. Returns shape (384,) normalized vector.""" cache_path = self._cache_dir / f"{self._cache_key(text)}.npy" if cache_path.exists(): vec = np.load(cache_path) return np.asarray(vec, dtype=np.float32) vec = self._model.encode([text], normalize_embeddings=True)[0] vec = np.asarray(vec, dtype=np.float32) np.save(cache_path, vec) return vec def embed_batch(self, texts: list[str]) -> np.ndarray: """Embed multiple texts. Returns shape (n, 384) normalized matrix.""" results = [] uncached_texts: list[str] = [] uncached_indices: list[int] = [] for i, text in enumerate(texts): cache_path = self._cache_dir / f"{self._cache_key(text)}.npy" if cache_path.exists(): results.append((i, np.load(cache_path))) else: uncached_texts.append(text) uncached_indices.append(i) results.append((i, None)) if uncached_texts: vecs = self._model.encode(uncached_texts, normalize_embeddings=True) vecs = np.asarray(vecs, dtype=np.float32) for j, idx in enumerate(uncached_indices): vec = vecs[j] # Save to cache cache_path = self._cache_dir / f"{self._cache_key(uncached_texts[j])}.npy" np.save(cache_path, vec) # Update results for k, (ri, rv) in enumerate(results): if ri == idx: results[k] = (ri, vec) break # Sort by original index and stack results.sort(key=lambda x: x[0]) return np.stack([r[1] for r in results]) # type: ignore[misc]