from __future__ import annotations

import hashlib
import json
import sqlite3
from pathlib import Path


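class EmbeddingCache:
    """SQLite-backed cache that maps cache keys to embedding vectors."""

    def __init__(self, path: Path) -> None:
        # Create the parent directory first so sqlite3.connect() can
        # create the database file on a fresh setup.
        self.path = path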
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._connection = sqlite3.connect(self.path)
        self._connection.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings (
                cache_key TEXT PRIMARY KEY,
                vector_json TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        self._connection.commit()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._connection.close()

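    def get_many(self, keys: list[str]) -> dict[str, list[float]]:
        """Return cached vectors for `keys`; keys with no entry are omitted."""
        if not keys:
            return {}

        found: dict[str, list[float]] = {}
        # Query in chunks of 900 keys to stay under SQLite's historical
        # default limit of 999 bound parameters per statement.
        for start in range(0, len(keys), 900):
            chunk = keys[start : start + 900]
            placeholders = ",".join("?" for _ in chunk)
            rows = self._connection.execute(
                f"SELECT cache_key, vector_json FROM embeddings WHERE cache_key IN ({placeholders})",
                chunk,
            ).fetchall()
            for cache_key, vector_json in rows:
                # Vectors are stored as JSON arrays; coerce each element back to float.
                found[str(cache_key)] = [float(value) for value in json.loads(vector_json)]
        return found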

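    def set_many(self, values: dict[str, list[float]]) -> None:
        """Insert or overwrite the vectors in `values`, keyed by cache key.

        Existing rows are replaced, which also resets their `created_at`.
        """
        if not values:
            return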

        rows = [
            (cache_key, json.dumps(vector, separators=(",", ":")))
            for cache_key, vector in values.items()
        ]
        self._connection.executemany(
            """
            INSERT OR REPLACE INTO embeddings (cache_key, vector_json)
            VALUES (?, ?)
            """,
            rows,
        )
        self._connection.commit()


def embedding_cache_key(
    text: str,
    *,
    provider: str,
    model: str,
    dim: int,
    api_url: str,
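) -> str:
    """Return a SHA-256 hex digest of `text` plus the embedding configuration.

    The same text embedded with a different provider, model, dimension,
    or API URL gets a different key.
    """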
    payload = "\n".join(
        [
            f"provider={provider}",
            f"model={model}",
            f"dim={dim}",
            f"api_url={api_url}",
            text,
        ]
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()