from __future__ import annotations import hashlib import json import sqlite3 from pathlib import Path class EmbeddingCache: def __init__(self, path: Path) -> None: self.path = path self.path.parent.mkdir(parents=True, exist_ok=True) self._connection = sqlite3.connect(self.path) self._connection.execute( """ CREATE TABLE IF NOT EXISTS embeddings ( cache_key TEXT PRIMARY KEY, vector_json TEXT NOT NULL, created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP ) """ ) self._connection.commit() def close(self) -> None: self._connection.close() def get_many(self, keys: list[str]) -> dict[str, list[float]]: if not keys: return {} found: dict[str, list[float]] = {} for start in range(0, len(keys), 900): chunk = keys[start : start + 900] placeholders = ",".join("?" for _ in chunk) rows = self._connection.execute( f"SELECT cache_key, vector_json FROM embeddings WHERE cache_key IN ({placeholders})", chunk, ).fetchall() for cache_key, vector_json in rows: found[str(cache_key)] = [float(value) for value in json.loads(vector_json)] return found def set_many(self, values: dict[str, list[float]]) -> None: if not values: return rows = [ (cache_key, json.dumps(vector, separators=(",", ":"))) for cache_key, vector in values.items() ] self._connection.executemany( """ INSERT OR REPLACE INTO embeddings (cache_key, vector_json) VALUES (?, ?) """, rows, ) self._connection.commit() def embedding_cache_key( text: str, *, provider: str, model: str, dim: int, api_url: str, ) -> str: payload = "\n".join( [ f"provider={provider}", f"model={model}", f"dim={dim}", f"api_url={api_url}", text, ] ) return hashlib.sha256(payload.encode("utf-8")).hexdigest()