from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Iterable

import torch


@dataclass
class _CacheEntry:
    """A cached embedding together with its estimated size in bytes."""

    value: Any
    size_bytes: int


class TextEncoderCache:
    """Size-bounded LRU cache for text-encoder embeddings.

    Cached values are kept detached on CPU and moved to the requested
    device on retrieval.
    """

    def __init__(self, max_size_mb: float = 100) -> None:
        self.max_size_bytes = int(max_size_mb * 1024 * 1024)
        self._entries: OrderedDict[Hashable, _CacheEntry] = OrderedDict()
        self._size_bytes = 0

    def encode(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts: Iterable[str] | str,
        device: torch.device | str | None = None,
        parallel: bool = False,
        cache_keys: Iterable[Hashable] | Hashable | None = None,
    ) -> list[Any]:
        """Encode ``prompts`` with ``encode_fn``, reusing cached embeddings.

        When ``parallel`` is False, uncached prompts are encoded one at a
        time; when True, all cache misses are encoded in a single batched
        call. ``cache_keys`` overrides the cache keys; by default the prompt
        strings themselves are used.
        """
        if isinstance(prompts, str):
            prompts_list = [prompts]
        else:
            prompts_list = list(prompts)
        if not prompts_list:
            return []

        # Normalise cache keys: a single (non-list) key may be passed for a single prompt.
        if cache_keys is None:
            keys_list = prompts_list
        else:
            if len(prompts_list) == 1 and not isinstance(cache_keys, list):
                keys_list = [cache_keys]
            else:
                keys_list = list(cache_keys)
            if len(keys_list) != len(prompts_list):
                raise ValueError("cache_keys must match the number of prompts.")

        if not parallel:
            # Sequential path: encode cache misses one prompt at a time.
            results: list[Any] = []
            for prompt, cache_key in zip(prompts_list, keys_list):
                cached = self._entries.get(cache_key)
                if cached is not None:
                    # Cache hit: refresh LRU order and move to the target device.
                    self._entries.move_to_end(cache_key)
                    results.append(self._to_device(cached.value, device))
                    continue
                encoded = encode_fn([prompt])
                if isinstance(encoded, (list, tuple)):
                    if not encoded:
                        raise ValueError("encode_fn returned empty embeddings.")
                    encoded_item = encoded[0]
                else:
                    encoded_item = encoded
                results.append(self._store(cache_key, encoded_item, device))
            return results

        # Batched path: look up all keys first, then encode the misses in one call.
        results = [None] * len(prompts_list)
        missing_prompts: list[str] = []
        missing_indices: list[int] = []
        missing_keys: list[Hashable] = []

        for idx, (prompt, cache_key) in enumerate(zip(prompts_list, keys_list)):
            cached = self._entries.get(cache_key)
            if cached is None:
                missing_prompts.append(prompt)
                missing_indices.append(idx)
                missing_keys.append(cache_key)
                continue
            self._entries.move_to_end(cache_key)
            results[idx] = self._to_device(cached.value, device)

        if missing_prompts:
            encoded_batch = encode_fn(missing_prompts)
            if not isinstance(encoded_batch, list):
                encoded_batch = list(encoded_batch)
            if len(encoded_batch) != len(missing_prompts):
                raise ValueError("encode_fn returned unexpected number of embeddings.")
            for cache_key, idx, encoded in zip(missing_keys, missing_indices, encoded_batch):
                results[idx] = self._store(cache_key, encoded, device)

        return results

    def _store(self, cache_key: Hashable, encoded: Any, device: torch.device | str | None) -> Any:
        # Keep a detached CPU copy in the cache; return the value on the requested device.
        cached_value = self._detach_to_cpu(encoded)
        size_bytes = self._estimate_size_bytes(cached_value)
        if size_bytes <= self.max_size_bytes:
            existing = self._entries.pop(cache_key, None)
            if existing is not None:
                self._size_bytes -= existing.size_bytes
            self._entries[cache_key] = _CacheEntry(cached_value, size_bytes)
            self._size_bytes += size_bytes
            self._purge_if_needed()
        else:
            # Too large to ever fit the budget; just refresh the LRU order if present.
            if cache_key in self._entries:
                self._entries.move_to_end(cache_key)
        return self._to_device(encoded, device)

    def _purge_if_needed(self) -> None:
        if self._size_bytes <= self.max_size_bytes:
            return
        # Evict least-recently-used entries until the cache fits the byte budget.
        while self._entries and self._size_bytes > self.max_size_bytes:
            _, entry = self._entries.popitem(last=False)
            self._size_bytes -= entry.size_bytes

    def _estimate_size_bytes(self, value: Any) -> int:
        if torch.is_tensor(value):
            return int(value.numel() * value.element_size())
        if isinstance(value, dict):
            return sum(self._estimate_size_bytes(v) for v in value.values())
        if isinstance(value, (list, tuple)):
            return sum(self._estimate_size_bytes(v) for v in value)
        return 0

    def _detach_to_cpu(self, value: Any) -> Any:
        if torch.is_tensor(value):
            if value.device.type == "cpu":
                return value.detach()
            return value.detach().to("cpu")
        if isinstance(value, dict):
            return {k: self._detach_to_cpu(v) for k, v in value.items()}
        if isinstance(value, tuple):
            items = [self._detach_to_cpu(v) for v in value]
            # Preserve NamedTuple subclasses such as encoder output containers.
            if hasattr(value, "_fields"):
                return value.__class__(*items)
            return tuple(items)
        if isinstance(value, list):
            return [self._detach_to_cpu(v) for v in value]
        return value

    def _to_device(self, value: Any, device: torch.device | str | None) -> Any:
        if device is None:
            return value
        if torch.is_tensor(value):
            return value.to(device)
        if isinstance(value, dict):
            return {k: self._to_device(v, device) for k, v in value.items()}
        if isinstance(value, tuple):
            items = [self._to_device(v, device) for v in value]
            if hasattr(value, "_fields"):
                return value.__class__(*items)
            return tuple(items)
        if isinstance(value, list):
            return [self._to_device(v, device) for v in value]
        return value
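

# Minimal usage sketch (assumption: any callable mapping a list of prompts to a
# list of embedding tensors works as ``encode_fn``; ``toy_encode`` below is a
# hypothetical stand-in, not a real text encoder).
if __name__ == "__main__":
    def toy_encode(batch: list[str]) -> list[torch.Tensor]:
        # Deterministic fake embeddings, one vector per prompt.
        return [torch.full((4,), float(len(text))) for text in batch]

    cache = TextEncoderCache(max_size_mb=1)
    first = cache.encode(toy_encode, ["a cat", "a dog"], parallel=True)
    second = cache.encode(toy_encode, "a cat")  # served from the cache
    assert torch.equal(first[0], second[0])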