File size: 8,884 Bytes
2a5255e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import contextvars
import hashlib
import json
import os
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Iterable, Optional

import numpy as np
import torch

from kimodo.sanitize import sanitize_texts

# Session object for the current execution context, or None when no session is
# active. CachedTextEncoder.session_context() installs it; EmbeddingCache reads
# and writes its last_prompt_texts / last_prompt_embeddings / last_prompt_lengths
# attributes to skip re-encoding an identical prompt list. It is module-level,
# so one active session is seen by every CachedTextEncoder instance.
_ACTIVE_SESSION = contextvars.ContextVar("kimodo_demo_active_session", default=None)


@dataclass
class CacheStats:
    """Per-prompt lookup counters maintained by EmbeddingCache.get_or_encode.

    Requests answered by the active session's cached result are returned
    before any lookup happens and are not counted here.
    """

    hits: int = 0  # prompts served from the in-memory LRU
    misses: int = 0  # prompts found in neither memory nor disk; sent to the encoder
    disk_hits: int = 0  # prompts loaded from a .npy file (then promoted into the LRU)


class EmbeddingCache:
    """Disk-backed text embedding cache with a small in-memory LRU.

    Lookups go through three tiers, cheapest first:

    1. the active session (see ``_maybe_use_session_cache``), which remembers
       the last prompt list and the padded tensor returned for it;
    2. an in-memory LRU of per-prompt arrays (at most ``max_mem_entries``);
    3. one ``<key>.npy`` file per prompt under ``<base_dir>/<model_name>/``.

    Prompts missing from every tier are encoded in one batch and stored in
    both the LRU and on disk. ``index.json`` records per-entry metadata; it
    is written but never consulted for lookups.
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_id: str,
        base_dir: Optional[str] = None,
        max_mem_entries: int = 128,
    ) -> None:
        """Configure the cache. Nothing is read from or written to disk yet.

        Args:
            model_name: Sub-directory under the cache root and part of every
                cache key.
            encoder_id: Encoder identifier, also part of every cache key, so
                different encoders produce different entries.
            base_dir: Cache root. If falsy, the ``kimodo_EMBED_CACHE_DIR``
                environment variable is used, then
                ``~/.cache/kimodo_demo/embeddings``. ``~`` is expanded.
            max_mem_entries: Capacity of the in-memory LRU.
        """
        # NOTE(review): env var names are case-sensitive on POSIX and this one
        # has a lowercase prefix; confirm it is the intended public name.
        cache_root = base_dir or os.environ.get(
            "kimodo_EMBED_CACHE_DIR",
            os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
        )
        self.base_dir = os.path.expanduser(cache_root)
        self.model_name = model_name
        self.encoder_id = encoder_id
        self.max_mem_entries = max_mem_entries
        self.stats = CacheStats()

        # Guards the LRU, the index and the stats counters.
        self._lock = threading.Lock()
        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._index = {}
        self._index_loaded = False

    def _model_dir(self) -> str:
        """Directory holding this model's entries, index and markers."""
        return os.path.join(self.base_dir, self.model_name)

    def _index_path(self) -> str:
        """Path of the JSON metadata index."""
        return os.path.join(self._model_dir(), "index.json")

    def _prewarm_marker_path(self, key: str) -> str:
        """Path of the marker file recording a finished prewarm for ``key``."""
        return os.path.join(self._model_dir(), f"prewarm_{key}.json")

    def has_prewarm_marker(self, key: str) -> bool:
        """Return True if a prewarm marker for ``key`` exists on disk."""
        return os.path.exists(self._prewarm_marker_path(key))

    def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
        """Atomically write the prewarm marker for ``key`` (tmp file + rename)."""
        os.makedirs(self._model_dir(), exist_ok=True)
        payload = {"prompt_count": prompt_count, "updated_at": time.time()}
        tmp_path = f"{self._prewarm_marker_path(key)}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        os.replace(tmp_path, self._prewarm_marker_path(key))

    def _load_index(self) -> None:
        """Load ``index.json`` once. A missing, unreadable or malformed index
        is treated as empty, because it is metadata only."""
        if self._index_loaded:
            return
        index_path = self._index_path()
        if os.path.exists(index_path):
            try:
                with open(index_path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
            except (OSError, ValueError):
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                loaded = {}
            # Entries are later assigned by key, so anything but a JSON
            # object would break _disk_save.
            self._index = loaded if isinstance(loaded, dict) else {}
        self._index_loaded = True

    def _save_index(self) -> None:
        """Atomically write the in-memory index to ``index.json``.

        The whole file is rewritten from this process's view, so with several
        processes the last writer wins.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        tmp_path = f"{self._index_path()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f)
        os.replace(tmp_path, self._index_path())

    def _make_key(self, text: str) -> str:
        """SHA-256 hex digest of ``model_name|encoder_id|text``."""
        key_src = f"{self.model_name}|{self.encoder_id}|{text}"
        return hashlib.sha256(key_src.encode("utf-8")).hexdigest()

    def _entry_path(self, key: str) -> str:
        """Path of the ``.npy`` file storing the embedding for ``key``."""
        return os.path.join(self._model_dir(), f"{key}.npy")

    def _mem_get(self, key: str) -> Optional[np.ndarray]:
        """Return the LRU entry for ``key`` (marking it most recent), or None."""
        if key in self._mem_cache:
            self._mem_cache.move_to_end(key)
            return self._mem_cache[key]
        return None

    def _mem_put(self, key: str, value: np.ndarray) -> None:
        """Insert or refresh ``key`` and evict the oldest entries beyond capacity."""
        self._mem_cache[key] = value
        self._mem_cache.move_to_end(key)
        while len(self._mem_cache) > self.max_mem_entries:
            self._mem_cache.popitem(last=False)

    def _disk_load(self, key: str) -> Optional[np.ndarray]:
        """Load the on-disk entry for ``key``, or None if absent or unreadable."""
        path = self._entry_path(key)
        if not os.path.exists(path):
            return None
        try:
            return np.load(path)
        except Exception:
            # Deliberately best-effort: a corrupt or vanished file is a miss,
            # and the entry is re-encoded and rewritten.
            return None

    def _disk_save(self, key: str, value: np.ndarray) -> None:
        """Persist ``value`` as ``<key>.npy`` and record it in the in-memory
        index. Call ``_save_index`` to flush the index.

        The array goes to a temporary file that is moved into place with
        ``os.replace``, so readers never see a half-written entry. The index
        and prewarm markers are written the same way.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        final_path = self._entry_path(key)
        # The pid keeps concurrent processes from sharing one temp file.
        tmp_path = f"{final_path}.{os.getpid()}.tmp"
        # Save through an open file: given a path, np.save appends ".npy" to
        # names that do not already end with it.
        with open(tmp_path, "wb") as f:
            np.save(f, value)
        os.replace(tmp_path, final_path)
        self._index[key] = {
            "length": int(value.shape[0]),
            "dtype": str(value.dtype),
            "updated_at": time.time(),
        }

    def _maybe_use_session_cache(self, texts: list[str]):
        """Return the active session's last ``(tensor, lengths)`` if it was
        produced for exactly ``texts``, otherwise None."""
        # NOTE(review): the match is on the prompt list only, not on
        # model_name/encoder_id. If one session object is shared by encoders
        # of different models, another model's embeddings can be returned;
        # confirm sessions are per-model.
        session = _ACTIVE_SESSION.get()
        if session is None:
            return None
        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
            return session.last_prompt_embeddings, session.last_prompt_lengths
        return None

    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
        """Remember this result on the active session (no-op without one)."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return
        session.last_prompt_texts = texts
        session.last_prompt_embeddings = tensor
        session.last_prompt_lengths = lengths

    def get_or_encode(self, texts: Iterable[str], encoder):
        """Return padded embeddings for ``texts``, encoding only cache misses.

        Args:
            texts: A single prompt or an iterable of prompts. They pass through
                ``sanitize_texts`` before keying.
            encoder: Callable taking a list of prompts and returning
                ``(tensor, lengths)``. Row ``i`` of ``tensor`` is cut to
                ``lengths[i]`` along dim 1 before caching.

        Returns:
            ``(tensor, lengths)``. ``tensor`` is a CPU tensor of shape
            ``(len(texts), max(lengths), feat_dim)``, zero-padded along dim 1.
            ``lengths`` holds the per-prompt valid lengths. Empty input yields
            a ``(0, 0, 0)`` tensor and ``[]``. A result reused from the active
            session is the very object stored there, so do not mutate it.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        if not texts:
            # torch.empty() with no size is not a valid call; keep the 3-D
            # (batch, length, feature) layout of the non-empty result.
            return torch.empty((0, 0, 0)), []

        session_cache = self._maybe_use_session_cache(texts)
        if session_cache is not None:
            return session_cache

        arrays: list[Optional[np.ndarray]] = [None] * len(texts)
        lengths: list[int] = [0] * len(texts)
        misses: list[tuple[int, str, str]] = []  # (position, text, key)

        with self._lock:
            self._load_index()
            for idx, text in enumerate(texts):
                key = self._make_key(text)
                cached = self._mem_get(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = int(cached.shape[0])
                    self.stats.hits += 1
                    continue

                cached = self._disk_load(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = int(cached.shape[0])
                    self._mem_put(key, cached)
                    self.stats.disk_hits += 1
                    continue

                misses.append((idx, text, key))
                self.stats.misses += 1

        if misses:
            # The encoder runs outside the lock, so other threads' lookups are
            # not blocked while it works.
            miss_texts = [text for _, text, _ in misses]
            miss_tensor, miss_lengths = encoder(miss_texts)
            miss_tensor = miss_tensor.detach().cpu()
            if miss_tensor.dtype == torch.bfloat16:
                # NumPy has no bfloat16, so Tensor.numpy() rejects it; store
                # such embeddings as float32.
                miss_tensor = miss_tensor.float()
            miss_tensor_np = miss_tensor.numpy()

            with self._lock:
                self._load_index()
                for miss_idx, length in enumerate(miss_lengths):
                    idx, _text, key = misses[miss_idx]
                    length = int(length)  # lengths may arrive as tensor scalars
                    arr = miss_tensor_np[miss_idx, :length].copy()
                    arrays[idx] = arr
                    lengths[idx] = length
                    self._mem_put(key, arr)
                    self._disk_save(key, arr)
                self._save_index()

        # Feature size and dtype come from the first filled row. Rows stay None
        # only if the encoder returned fewer lengths than prompts.
        first = next((arr for arr in arrays if arr is not None), None)
        max_len = max(lengths)
        feat_dim = first.shape[-1] if first is not None else 0
        dtype = first.dtype if first is not None else np.float32
        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
        for idx, arr in enumerate(arrays):
            if arr is not None:
                padded[idx, : arr.shape[0]] = arr

        result = torch.from_numpy(padded)
        self._update_session_cache(texts, result, lengths)
        return result, lengths


class CachedTextEncoder:
    """Wrapper around a text encoder to add disk-backed caching.

    Calling the wrapper returns ``(tensor, lengths)`` from ``EmbeddingCache``
    and invokes the wrapped encoder only for prompts that miss the cache.
    Attributes not defined on the wrapper are forwarded to the wrapped encoder.
    """

    def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
        """
        Args:
            encoder: Callable mapping a list of prompts to ``(tensor, lengths)``.
            model_name: Cache namespace (sub-directory and key component).
            base_dir: Optional cache root, passed to ``EmbeddingCache``.
        """
        self.encoder = encoder
        self.model_name = model_name
        # Only the encoder's class name enters the cache key, so two instances
        # of the same class under one model_name share cache entries.
        encoder_id = type(encoder).__name__
        self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)

    def __call__(self, texts):
        """Return cached-or-freshly-encoded ``(tensor, lengths)`` for ``texts``."""
        return self.cache.get_or_encode(texts, self.encoder)

    def prewarm(self, texts) -> None:
        """Fill the cache for ``texts``, at most once per distinct prompt list.

        After a successful run, a marker keyed by a hash of the sanitized
        prompts is written; later calls with the same list return immediately.
        If a session is active, it records this result as its last prompt.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        # NOTE(review): "|" may occur inside prompts, so lists such as
        # ["a|b"] and ["a", "b"] share one marker and the second list's
        # prewarm is skipped.
        prewarm_key = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
        if self.cache.has_prewarm_marker(prewarm_key):
            return
        self.cache.get_or_encode(texts, self.encoder)
        self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))

    def to(self, device=None, dtype=None):
        """Forward ``.to()`` to the wrapped encoder if it has one; return self."""
        if hasattr(self.encoder, "to"):
            self.encoder.to(device=device, dtype=dtype)
        return self

    @contextlib.contextmanager
    def session_context(self, session):
        """Make ``session`` the active session for the ``with`` block.

        The previous value is restored on exit, even if the block raises.
        The active session is module-wide, so it applies to every
        CachedTextEncoder, not only this one.
        """
        token = _ACTIVE_SESSION.set(session)
        try:
            yield
        finally:
            _ACTIVE_SESSION.reset(token)

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped encoder.

        ``encoder`` is read from the instance dict instead of via
        ``self.encoder``. On an instance whose ``__init__`` has not run (e.g.
        during ``copy.copy`` or unpickling), ``self.encoder`` would re-enter
        this method forever; this way such lookups raise AttributeError.
        """
        try:
            encoder = self.__dict__["encoder"]
        except KeyError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(encoder, name)