# Wan2GP — shared/utils/text_encoder_cache.py
# Uploaded via huggingface_hub (author: attong39, commit f523f14, verified).
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Hashable
import torch
@dataclass
class _CacheEntry:
    """A single cache record: the stored encoder output and its byte footprint."""

    # Encoder output, detached and moved to CPU by TextEncoderCache before storage.
    value: Any
    # Approximate tensor payload of `value` (see TextEncoderCache._estimate_size_bytes);
    # counted against the cache's byte budget.
    size_bytes: int
class TextEncoderCache:
    """Size-bounded LRU cache for text-encoder embeddings.

    Stored values are detached and kept on CPU; ``encode`` can move results
    onto a caller-supplied device on the way out. The budget is an estimate
    based on tensor payload sizes only — non-tensor leaves count as zero
    bytes, so purely non-tensor values do not consume budget.
    """

    def __init__(self, max_size_mb: float = 100) -> None:
        """Create an empty cache holding roughly ``max_size_mb`` MiB of tensors."""
        self.max_size_bytes = int(max_size_mb * 1024 * 1024)
        # Recency-ordered map: least recently used entries sit at the front.
        self._entries: "OrderedDict[Hashable, _CacheEntry]" = OrderedDict()
        self._size_bytes = 0

    def encode(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        prompts: Iterable[str] | str,
        device: torch.device | str | None = None,
        parallel: bool = False,
        cache_keys: Iterable[Hashable] | Hashable | None = None,
    ) -> list[Any]:
        """Return one embedding per prompt, reusing cached results where possible.

        ``encode_fn`` receives a list of prompt strings and returns one
        embedding per prompt. With ``parallel=True`` every cache miss is sent
        to ``encode_fn`` in a single batched call; otherwise misses are
        encoded one at a time. ``cache_keys`` substitutes explicit hashable
        keys for the prompt text itself. Raises ``ValueError`` when the
        key/prompt counts disagree or ``encode_fn`` returns the wrong number
        of embeddings.
        """
        texts = [prompts] if isinstance(prompts, str) else list(prompts)
        if not texts:
            return []
        keys = self._normalize_keys(texts, cache_keys)

        if not parallel:
            return [
                self._encode_one(encode_fn, text, key, device)
                for text, key in zip(texts, keys)
            ]

        outputs: list[Any] = [None] * len(texts)
        # Misses are remembered as (position, prompt, key) so batched results
        # can be written back into their original slots.
        pending: list[tuple[int, str, Hashable]] = []
        for position, (text, key) in enumerate(zip(texts, keys)):
            hit = self._entries.get(key)
            if hit is None:
                pending.append((position, text, key))
            else:
                self._entries.move_to_end(key)
                outputs[position] = self._to_device(hit.value, device)
        if pending:
            batch = encode_fn([text for _, text, _ in pending])
            if not isinstance(batch, list):
                batch = list(batch)
            if len(batch) != len(pending):
                raise ValueError("encode_fn returned unexpected number of embeddings.")
            for (position, _, key), embedding in zip(pending, batch):
                outputs[position] = self._store(key, embedding, device)
        return outputs

    def _normalize_keys(
        self,
        texts: list[str],
        cache_keys: Iterable[Hashable] | Hashable | None,
    ) -> list[Hashable]:
        """Resolve ``cache_keys`` to one key per prompt (default: the prompt text)."""
        if cache_keys is None:
            return texts
        # A lone non-list value paired with a single prompt is itself the key.
        if len(texts) == 1 and not isinstance(cache_keys, list):
            return [cache_keys]
        keys = list(cache_keys)
        if len(keys) != len(texts):
            raise ValueError("cache_keys must match the number of prompts.")
        return keys

    def _encode_one(
        self,
        encode_fn: Callable[[list[str]], list[Any]],
        text: str,
        key: Hashable,
        device: torch.device | str | None,
    ) -> Any:
        """Encode a single prompt, serving a cache hit without calling ``encode_fn``."""
        hit = self._entries.get(key)
        if hit is not None:
            self._entries.move_to_end(key)
            return self._to_device(hit.value, device)
        produced = encode_fn([text])
        # encode_fn may return either a one-element sequence or a bare embedding.
        if isinstance(produced, (list, tuple)):
            if not produced:
                raise ValueError("encode_fn returned empty embeddings.")
            produced = produced[0]
        return self._store(key, produced, device)

    def _store(self, cache_key: Hashable, encoded: Any, device: torch.device | str | None) -> Any:
        """Cache ``encoded`` under ``cache_key`` if it fits, then return it on ``device``."""
        snapshot = self._detach_to_cpu(encoded)
        nbytes = self._estimate_size_bytes(snapshot)
        if nbytes > self.max_size_bytes:
            # Too large to ever cache; just refresh the key's recency if present.
            if cache_key in self._entries:
                self._entries.move_to_end(cache_key)
        else:
            displaced = self._entries.pop(cache_key, None)
            if displaced is not None:
                self._size_bytes -= displaced.size_bytes
            self._entries[cache_key] = _CacheEntry(snapshot, nbytes)
            self._size_bytes += nbytes
            self._purge_if_needed()
        # Hand back the caller's original encoding, not the cached CPU copy.
        return self._to_device(encoded, device)

    def _purge_if_needed(self) -> None:
        """Evict least-recently-used entries until the byte budget is respected."""
        while self._entries and self._size_bytes > self.max_size_bytes:
            _, evicted = self._entries.popitem(last=False)
            self._size_bytes -= evicted.size_bytes

    def _estimate_size_bytes(self, value: Any) -> int:
        """Approximate tensor payload of ``value``, recursing through containers."""
        if torch.is_tensor(value):
            return int(value.numel() * value.element_size())
        if isinstance(value, dict):
            return sum(self._estimate_size_bytes(item) for item in value.values())
        if isinstance(value, (list, tuple)):
            return sum(self._estimate_size_bytes(item) for item in value)
        # NOTE(review): non-tensor leaves are counted as zero bytes, so values
        # containing no tensors never consume budget — confirm this is intended.
        return 0

    def _detach_to_cpu(self, value: Any) -> Any:
        """Recursively detach tensors and move them to CPU for storage."""
        if torch.is_tensor(value):
            detached = value.detach()
            return detached if value.device.type == "cpu" else detached.to("cpu")
        return self._map_structure(value, self._detach_to_cpu)

    def _to_device(self, value: Any, device: torch.device | str | None) -> Any:
        """Recursively move tensors in ``value`` to ``device`` (no-op when None)."""
        if device is None:
            return value
        if torch.is_tensor(value):
            return value.to(device)
        return self._map_structure(value, lambda item: self._to_device(item, device))

    def _map_structure(self, value: Any, fn: Callable[[Any], Any]) -> Any:
        """Apply ``fn`` over dict values, list items and (named)tuples; pass other values through."""
        if isinstance(value, dict):
            return {key: fn(item) for key, item in value.items()}
        if isinstance(value, tuple):
            mapped = [fn(item) for item in value]
            # Rebuild namedtuples with their original type so field access survives.
            if hasattr(value, "_fields"):
                return value.__class__(*mapped)
            return tuple(mapped)
        if isinstance(value, list):
            return [fn(item) for item in value]
        return value