| from typing import Any, Optional, Tuple, List |
| import hashlib |
| import json |
| import logging |
| import threading |
|
|
| |
| from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue |
|
|
| _logger = logging.getLogger(__name__) |
|
|
|
|
| _providers: List[CacheProvider] = [] |
| _providers_lock = threading.Lock() |
| _providers_snapshot: Tuple[CacheProvider, ...] = () |
|
|
|
|
| def register_cache_provider(provider: CacheProvider) -> None: |
| """Register an external cache provider. Providers are called in registration order.""" |
| global _providers_snapshot |
| with _providers_lock: |
| if provider in _providers: |
| _logger.warning(f"Provider {provider.__class__.__name__} already registered") |
| return |
| _providers.append(provider) |
| _providers_snapshot = tuple(_providers) |
| _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") |
|
|
|
|
| def unregister_cache_provider(provider: CacheProvider) -> None: |
| global _providers_snapshot |
| with _providers_lock: |
| try: |
| _providers.remove(provider) |
| _providers_snapshot = tuple(_providers) |
| _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") |
| except ValueError: |
| _logger.warning(f"Provider {provider.__class__.__name__} was not registered") |
|
|
|
|
| def _get_cache_providers() -> Tuple[CacheProvider, ...]: |
| return _providers_snapshot |
|
|
|
|
| def _has_cache_providers() -> bool: |
| return bool(_providers_snapshot) |
|
|
|
|
| def _clear_cache_providers() -> None: |
| global _providers_snapshot |
| with _providers_lock: |
| _providers.clear() |
| _providers_snapshot = () |
|
|
|
|
| def _canonicalize(obj: Any) -> Any: |
| |
| |
| |
| |
| if isinstance(obj, frozenset): |
| return ("__frozenset__", sorted( |
| [_canonicalize(item) for item in obj], |
| key=lambda x: json.dumps(x, sort_keys=True) |
| )) |
| elif isinstance(obj, set): |
| return ("__set__", sorted( |
| [_canonicalize(item) for item in obj], |
| key=lambda x: json.dumps(x, sort_keys=True) |
| )) |
| elif isinstance(obj, tuple): |
| return ("__tuple__", [_canonicalize(item) for item in obj]) |
| elif isinstance(obj, list): |
| return [_canonicalize(item) for item in obj] |
| elif isinstance(obj, dict): |
| return {"__dict__": sorted( |
| [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], |
| key=lambda x: json.dumps(x, sort_keys=True) |
| )} |
| elif isinstance(obj, (int, float, str, bool, type(None))): |
| return (type(obj).__name__, obj) |
| elif isinstance(obj, bytes): |
| return ("__bytes__", obj.hex()) |
| else: |
| raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") |
|
|
|
|
| def _serialize_cache_key(cache_key: Any) -> Optional[str]: |
| |
| |
| try: |
| canonical = _canonicalize(cache_key) |
| json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) |
| return hashlib.sha256(json_str.encode('utf-8')).hexdigest() |
| except Exception as e: |
| _logger.warning(f"Failed to serialize cache key: {e}") |
| return None |
|
|
|
|
| def _contains_self_unequal(obj: Any) -> bool: |
| |
| |
| try: |
| if not (obj == obj): |
| return True |
| except Exception: |
| return True |
| if isinstance(obj, (frozenset, tuple, list, set)): |
| return any(_contains_self_unequal(item) for item in obj) |
| if isinstance(obj, dict): |
| return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) |
| if hasattr(obj, 'value'): |
| return _contains_self_unequal(obj.value) |
| return False |
|
|
|
|
| def _estimate_value_size(value: CacheValue) -> int: |
| try: |
| import torch |
| except ImportError: |
| return 0 |
|
|
| total = 0 |
|
|
| def estimate(obj): |
| nonlocal total |
| if isinstance(obj, torch.Tensor): |
| total += obj.numel() * obj.element_size() |
| elif isinstance(obj, dict): |
| for v in obj.values(): |
| estimate(v) |
| elif isinstance(obj, (list, tuple)): |
| for item in obj: |
| estimate(item) |
|
|
| for output in value.outputs: |
| estimate(output) |
| return total |
|
|