File size: 2,231 Bytes
34b531b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | from __future__ import annotations
import hashlib
import json
import sqlite3
from pathlib import Path
class EmbeddingCache:
    """SQLite-backed store mapping cache keys to embedding vectors.

    Each vector is serialized as compact JSON text in a single
    ``embeddings`` table keyed by ``cache_key``. The instance holds one
    connection for its whole lifetime; call :meth:`close`, or use the
    instance as a context manager, to release it.
    """

    # Upper bound on the number of keys bound into one ``IN (...)`` query.
    # NOTE(review): presumably chosen to stay below SQLite's historical
    # default limit of 999 bound parameters per statement -- confirm.
    _MAX_KEYS_PER_QUERY = 900

    def __init__(self, path: Path) -> None:
        """Open the cache database at *path*, creating it if necessary.

        Missing parent directories are created, and the ``embeddings``
        table is created when it does not exist yet.

        Raises:
            sqlite3.Error: if the schema cannot be created; the connection
                is closed before the error propagates.
        """
        self.path = path
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._connection = sqlite3.connect(self.path)
        try:
            # ``with connection`` commits on success, rolls back on error.
            with self._connection:
                self._connection.execute(
                    """
                    CREATE TABLE IF NOT EXISTS embeddings (
                        cache_key TEXT PRIMARY KEY,
                        vector_json TEXT NOT NULL,
                        created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
        except sqlite3.Error:
            # Do not leak the handle when the file is not a usable database.
            self._connection.close()
            raise

    def __enter__(self) -> "EmbeddingCache":
        """Return the cache itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._connection.close()

    def get_many(self, keys: list[str]) -> dict[str, list[float]]:
        """Return the cached vectors for *keys*.

        Keys without a stored row are simply absent from the result. Every
        stored element is converted to ``float`` on the way out.
        """
        if not keys:
            return {}
        found: dict[str, list[float]] = {}
        step = self._MAX_KEYS_PER_QUERY
        # Query in bounded chunks so one statement never binds too many
        # parameters, however long *keys* is.
        for start in range(0, len(keys), step):
            chunk = keys[start : start + step]
            placeholders = ",".join("?" for _ in chunk)
            rows = self._connection.execute(
                f"SELECT cache_key, vector_json FROM embeddings WHERE cache_key IN ({placeholders})",
                chunk,
            ).fetchall()
            for cache_key, vector_json in rows:
                found[str(cache_key)] = [float(value) for value in json.loads(vector_json)]
        return found

    def set_many(self, values: dict[str, list[float]]) -> None:
        """Insert or overwrite the vectors in *values* as one transaction.

        Either every row is stored or none is: if any row fails, the whole
        batch is rolled back, so a later commit cannot persist a partial
        write.
        """
        if not values:
            return
        rows = [
            (cache_key, json.dumps(vector, separators=(",", ":")))
            for cache_key, vector in values.items()
        ]
        # Commit on success, roll back on error (all-or-nothing batch).
        with self._connection:
            self._connection.executemany(
                """
                INSERT OR REPLACE INTO embeddings (cache_key, vector_json)
                VALUES (?, ?)
                """,
                rows,
            )
def embedding_cache_key(
    text: str,
    *,
    provider: str,
    model: str,
    dim: int,
    api_url: str,
) -> str:
    """Return the SHA-256 hex digest identifying *text* for one embedding setup.

    The digest covers the provider, model, dimension and API URL (each as a
    ``name=value`` line) followed by the text itself, all newline-joined and
    UTF-8 encoded, so the same text under a different configuration yields a
    different key.
    """
    header_lines = (
        f"provider={provider}",
        f"model={model}",
        f"dim={dim}",
        f"api_url={api_url}",
    )
    # The text goes last, after the configuration header.
    encoded = "\n".join((*header_lines, text)).encode("utf-8")
    hasher = hashlib.sha256()
    hasher.update(encoded)
    return hasher.hexdigest()
|