from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Iterable

import torch


@dataclass
class _CacheEntry:
    """One cached encoder output together with the byte size it is charged."""

    value: Any  # detached, CPU-resident copy of the encoder output
    size_bytes: int  # size as computed by TextEncoderCache._estimate_size_bytes


class TextEncoderCache:
    """Byte-bounded LRU cache for text-encoder outputs.

    Entries are keyed by prompt text (or by caller-supplied ``cache_keys``),
    stored detached on the CPU, and evicted least-recently-used first whenever
    the total estimated size exceeds ``max_size_mb``. Values whose size cannot
    be estimated (anything other than tensors, objects exposing an integer
    ``nbytes``, or dicts/lists/tuples of those) are charged 0 bytes and so do
    not count toward the limit.
    """

    def __init__(self, max_size_mb: float = 100) -> None:
        self.max_size_bytes = int(max_size_mb * 1024 * 1024)
        # Ordered from least to most recently used; eviction pops the front.
        self._entries: "OrderedDict[Hashable, _CacheEntry]" = OrderedDict()
        self._size_bytes = 0

    def encode(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts: Iterable[str] | str,
        device: torch.device | str | None = None,
        parallel: bool = False,
        cache_keys: Iterable[Hashable] | Hashable | None = None,
    ) -> list[Any]:
        """Return one encoding per prompt, calling ``encode_fn`` only on misses.

        Args:
            encode_fn: Encodes a list of prompts into one output per prompt.
                In sequential mode it is called with a one-prompt list; a
                list/tuple result contributes its first element, any other
                result is used as-is.
            prompts: A single prompt or an iterable of prompts.
            device: If given, returned tensors are moved to this device; if
                ``None`` they are returned where they are.
            parallel: If true, all misses are encoded in a single
                ``encode_fn`` call; prompts sharing a cache key are encoded
                only once.
            cache_keys: Optional key(s) to cache under instead of the prompt
                text. A ``str``/``bytes`` or non-iterable value is a single
                key; with a single prompt any non-list value is a single key.

        Returns:
            A list of outputs, in prompt order.

        Raises:
            ValueError: If the number of cache keys differs from the number of
                prompts, or ``encode_fn`` returns no output / the wrong number
                of outputs.

        Note:
            On a cache hit with ``device`` of ``None`` (or ``"cpu"``), the
            returned tensor may share storage with the cache entry; clone it
            before modifying it in place.
        """
        if isinstance(prompts, str):
            prompts_list = [prompts]
        else:
            prompts_list = list(prompts)
        if not prompts_list:
            return []

        keys_list = self._normalize_keys(prompts_list, cache_keys)
        if parallel:
            return self._encode_batched(encode_fn, prompts_list, keys_list, device)
        return self._encode_sequential(encode_fn, prompts_list, keys_list, device)

    def _normalize_keys(
        self,
        prompts_list: list[str],
        cache_keys: Iterable[Hashable] | Hashable | None,
    ) -> list[Hashable]:
        """Turn ``cache_keys`` into a list aligned one-to-one with the prompts."""
        if cache_keys is None:
            return prompts_list
        if isinstance(cache_keys, (str, bytes)) or not isinstance(cache_keys, Iterable):
            # A scalar key: never split a string into characters, and let a
            # count mismatch surface as ValueError rather than TypeError.
            keys_list = [cache_keys]
        elif len(prompts_list) == 1 and not isinstance(cache_keys, list):
            # With one prompt, e.g. a tuple is a compound key, not a key list.
            keys_list = [cache_keys]
        else:
            keys_list = list(cache_keys)
        if len(keys_list) != len(prompts_list):
            raise ValueError("cache_keys must match the number of prompts.")
        return keys_list

    def _lookup(self, cache_key: Hashable) -> _CacheEntry | None:
        """Return the entry for ``cache_key`` (marking it most recent), or None."""
        entry = self._entries.get(cache_key)
        if entry is not None:
            self._entries.move_to_end(cache_key)
        return entry

    def _encode_sequential(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts_list: list[str],
        keys_list: list[Hashable],
        device: torch.device | str | None,
    ) -> list[Any]:
        """Encode misses one prompt at a time, consulting the cache before each."""
        results: list[Any] = []
        for prompt, cache_key in zip(prompts_list, keys_list):
            entry = self._lookup(cache_key)
            if entry is not None:
                results.append(self._to_device(entry.value, device))
                continue
            encoded = encode_fn([prompt])
            if isinstance(encoded, (list, tuple)):
                if not encoded:
                    raise ValueError("encode_fn returned empty embeddings.")
                encoded = encoded[0]
            results.append(self._store(cache_key, encoded, device))
        return results

    def _encode_batched(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts_list: list[str],
        keys_list: list[Hashable],
        device: torch.device | str | None,
    ) -> list[Any]:
        """Encode all misses with one ``encode_fn`` call, once per unique key."""
        results: list[Any] = [None] * len(prompts_list)
        # Missing key -> indices that need it, in first-occurrence order.
        pending: "OrderedDict[Hashable, list[int]]" = OrderedDict()
        pending_prompts: list[str] = []
        for idx, (prompt, cache_key) in enumerate(zip(prompts_list, keys_list)):
            entry = self._lookup(cache_key)
            if entry is not None:
                results[idx] = self._to_device(entry.value, device)
            elif cache_key in pending:
                pending[cache_key].append(idx)
            else:
                pending[cache_key] = [idx]
                pending_prompts.append(prompt)

        if not pending:
            return results

        encoded_batch = encode_fn(pending_prompts)
        if not isinstance(encoded_batch, list):
            encoded_batch = list(encoded_batch)
        if len(encoded_batch) != len(pending_prompts):
            raise ValueError("encode_fn returned unexpected number of embeddings.")

        for (cache_key, indices), encoded in zip(pending.items(), encoded_batch):
            first, *repeats = indices
            results[first] = self._store(cache_key, encoded, device)
            if repeats:
                # Mirror the sequential path: later occurrences behave like
                # cache hits (falling back to the fresh value if uncacheable).
                entry = self._entries.get(cache_key)
                source = entry.value if entry is not None else encoded
                for idx in repeats:
                    results[idx] = self._to_device(source, device)
        return results

    def _store(self, cache_key: Hashable, encoded: Any, device: torch.device | str | None) -> Any:
        """Cache a CPU copy of ``encoded`` if it fits; return ``encoded`` on ``device``."""
        # Measure before copying so oversized values are never copied to CPU.
        size_bytes = self._estimate_size_bytes(encoded)
        if size_bytes <= self.max_size_bytes:
            existing = self._entries.pop(cache_key, None)
            if existing is not None:
                self._size_bytes -= existing.size_bytes
            self._entries[cache_key] = _CacheEntry(self._detach_to_cpu(encoded), size_bytes)
            self._size_bytes += size_bytes
            self._purge_if_needed()
        elif cache_key in self._entries:
            self._entries.move_to_end(cache_key)
        return self._to_device(encoded, device)

    def _purge_if_needed(self) -> None:
        """Evict least-recently-used entries until the size limit is respected."""
        while self._entries and self._size_bytes > self.max_size_bytes:
            _, entry = self._entries.popitem(last=False)
            self._size_bytes -= entry.size_bytes

    def _estimate_size_bytes(self, value: Any) -> int:
        """Return the summed payload size of tensors/arrays inside ``value``."""
        if torch.is_tensor(value):
            return int(value.numel() * value.element_size())
        if isinstance(value, dict):
            return sum(self._estimate_size_bytes(v) for v in value.values())
        if isinstance(value, (list, tuple)):
            return sum(self._estimate_size_bytes(v) for v in value)
        # Array-likes (e.g. numpy arrays, memoryviews) report their own size.
        nbytes = getattr(value, "nbytes", None)
        if isinstance(nbytes, int):
            return nbytes
        return 0

    def _detach_to_cpu(self, value: Any) -> Any:
        """Return a detached CPU copy of ``value`` that owns its own storage."""
        if torch.is_tensor(value):
            if value.device.type == "cpu":
                # detach() alone would share storage with the caller's tensor
                # (and pin a whole batch when value is a row view of it).
                return value.detach().clone()
            return value.detach().to("cpu")
        if isinstance(value, dict):
            return {k: self._detach_to_cpu(v) for k, v in value.items()}
        if isinstance(value, tuple):
            items = [self._detach_to_cpu(v) for v in value]
            if hasattr(value, "_fields"):  # namedtuple: rebuild the same type
                return value.__class__(*items)
            return tuple(items)
        if isinstance(value, list):
            return [self._detach_to_cpu(v) for v in value]
        return value

    def _to_device(self, value: Any, device: torch.device | str | None) -> Any:
        """Move tensors inside ``value`` to ``device``; no-op when it is None."""
        if device is None:
            return value
        if torch.is_tensor(value):
            return value.to(device)
        if isinstance(value, dict):
            return {k: self._to_device(v, device) for k, v in value.items()}
        if isinstance(value, tuple):
            items = [self._to_device(v, device) for v in value]
            if hasattr(value, "_fields"):  # namedtuple: rebuild the same type
                return value.__class__(*items)
            return tuple(items)
        if isinstance(value, list):
            return [self._to_device(v, device) for v in value]
        return value