agentbench / agent_bench /rag /embedder.py
Nomearod's picture
fix: recursive chunker applies overlap, cache key includes model name
77bdc95
"""Embedding wrapper around sentence-transformers with disk cache."""
from __future__ import annotations
import hashlib
from pathlib import Path
from typing import Any
import numpy as np
class Embedder:
    """Embed text with a sentence-transformers style model and a disk cache.

    Any object exposing ``encode(texts, normalize_embeddings=True)`` can be
    injected as ``model``, so tests can pass a mock instead of downloading
    the 80MB model.

    Each embedding is stored as ``<sha256(model_name:text)>.npy`` inside
    ``cache_dir``; the key includes the model name so that different models
    never share cache entries.
    """

    def __init__(
        self,
        model: Any = None,
        model_name: str = "all-MiniLM-L6-v2",
        cache_dir: str = ".cache/embeddings",
    ) -> None:
        """Create the embedder and make sure the cache directory exists.

        Args:
            model: Pre-built encoder object. When ``None`` a
                ``SentenceTransformer(model_name)`` is created.
            model_name: Name used to load the model and to scope cache keys.
            cache_dir: Directory holding the cached ``.npy`` files; created
                (including parents) if missing.
        """
        self._model_name = model_name
        if model is not None:
            self._model = model
        else:
            # Imported lazily so an injected model never needs the package.
            from sentence_transformers import SentenceTransformer

            self._model = SentenceTransformer(model_name)
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)

    def _cache_key(self, text: str) -> str:
        """Cache key scoped to model name + text content."""
        raw = f"{self._model_name}:{text}"
        return hashlib.sha256(raw.encode()).hexdigest()

    def _cache_path(self, text: str) -> Path:
        """Return the ``.npy`` file path that caches ``text``."""
        return self._cache_dir / f"{self._cache_key(text)}.npy"

    def _load_cached(self, text: str) -> np.ndarray | None:
        """Return the cached float32 vector for ``text``, or ``None`` on a miss.

        A missing file and an unreadable one (e.g. truncated by an
        interrupted write) are both reported as a miss, so the caller
        recomputes the vector and overwrites the bad entry instead of
        failing on every later call.
        """
        try:
            vec = np.load(self._cache_path(text))
        except (OSError, ValueError, EOFError):
            # FileNotFoundError is an OSError; ValueError/EOFError are what
            # np.load raises for malformed or empty files.
            return None
        return np.asarray(vec, dtype=np.float32)

    def _store(self, text: str, vec: np.ndarray) -> None:
        """Write ``vec`` to the cache entry for ``text``."""
        np.save(self._cache_path(text), vec)

    def _empty_batch(self) -> np.ndarray:
        """Return a float32 matrix with zero rows for an empty batch.

        The column count comes from the model's
        ``get_sentence_embedding_dimension()`` when it offers one and returns
        a positive int; otherwise it is 0 because the width cannot be known
        without encoding something.
        """
        getter = getattr(self._model, "get_sentence_embedding_dimension", None)
        dim = getter() if callable(getter) else None
        width = dim if isinstance(dim, int) and dim > 0 else 0
        return np.empty((0, width), dtype=np.float32)

    def embed(self, text: str) -> np.ndarray:
        """Embed a single text string.

        Returns:
            A 1-D float32 vector produced with ``normalize_embeddings=True``
            (shape ``(384,)`` for the default model). Served from the disk
            cache when present, otherwise computed and then cached.
        """
        cached = self._load_cached(text)
        if cached is not None:
            return cached
        encoded = self._model.encode([text], normalize_embeddings=True)[0]
        vec = np.asarray(encoded, dtype=np.float32)
        self._store(text, vec)
        return vec

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        """Embed multiple texts, preserving input order.

        Cached texts are read from disk; the remaining *distinct* texts are
        encoded in a single model call and written to the cache.

        Returns:
            A float32 matrix with one row per input text (shape ``(n, 384)``
            for the default model). An empty input yields a matrix with zero
            rows instead of raising.
        """
        if not texts:
            return self._empty_batch()

        rows: list[np.ndarray | None] = [None] * len(texts)
        # text -> every position in ``texts`` that still needs its vector.
        # Keyed by text so duplicates in one batch are encoded only once.
        pending: dict[str, list[int]] = {}
        for pos, text in enumerate(texts):
            if text in pending:
                pending[text].append(pos)
                continue
            cached = self._load_cached(text)
            if cached is None:
                pending[text] = [pos]
            else:
                rows[pos] = cached

        if pending:
            fresh_texts = list(pending)
            encoded = np.asarray(
                self._model.encode(fresh_texts, normalize_embeddings=True),
                dtype=np.float32,
            )
            for j, text in enumerate(fresh_texts):
                # Indexing (not zip) so a model returning too few rows fails
                # loudly with IndexError rather than leaving gaps.
                vec = encoded[j]
                self._store(text, vec)
                for pos in pending[text]:
                    rows[pos] = vec

        return np.stack(rows)  # type: ignore[arg-type]