| from queue import Queue | |
| from typing import TYPE_CHECKING, Optional, TypeVar | |
| from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase | |
| T = TypeVar("T") | |
| if TYPE_CHECKING: | |
| from invokeai.app.services.invoker import Invoker | |
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
    """
    Provides a bounded in-memory cache in front of an instance of `ObjectSerializerBase`.

    Saving an object always writes through to the underlying storage and then caches it.
    Loading checks the cache first and falls back to the underlying storage, caching the result.

    Eviction is first-in, first-out: once more than `max_cache_size` objects are cached, the
    object that was cached earliest is dropped. A cache hit does not refresh an entry's position.
    No locking is performed by this class.
    """

    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
        """
        :param underlying_storage: The serializer that actually persists objects.
        :param max_cache_size: The maximum number of objects kept in memory.
        """
        super().__init__()
        self._underlying_storage = underlying_storage
        # dicts preserve insertion order, so the first key is always the oldest cached entry.
        # Keeping the order in the dict itself (rather than in a separate id queue) means a
        # deleted entry can never leave a stale id behind.
        self._cache: dict[str, T] = {}
        self._max_cache_size = max_cache_size

    def start(self, invoker: "Invoker") -> None:
        """Stores the invoker and forwards `start` to the underlying storage, if it defines it."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: "Invoker") -> None:
        """Stores the invoker and forwards `stop` to the underlying storage, if it defines it."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def load(self, name: str) -> T:
        """
        Returns the object stored under `name`.

        A cached object is returned directly; otherwise the object is loaded from the underlying
        storage (any error it raises propagates) and cached before being returned.
        """
        cache_item = self._get_cache(name)
        if cache_item is not None:
            return cache_item

        obj = self._underlying_storage.load(name)
        self._set_cache(name, obj)
        return obj

    def save(self, obj: T) -> str:
        """Saves `obj` to the underlying storage, caches it, and returns the name it was saved under."""
        name = self._underlying_storage.save(obj)
        self._set_cache(name, obj)
        return name

    def delete(self, name: str) -> None:
        """Deletes `name` from the underlying storage, drops any cached copy, then calls `_on_deleted`."""
        self._underlying_storage.delete(name)
        # The entry may never have been cached or may already have been evicted.
        self._cache.pop(name, None)
        self._on_deleted(name)

    def _get_cache(self, name: str) -> Optional[T]:
        """Returns the cached object for `name`, or None if it is not cached."""
        return self._cache.get(name)

    def _set_cache(self, name: str, data: T) -> None:
        """
        Caches `data` under `name` unless that name is already cached, then evicts the oldest
        entries until the cache holds at most `max_cache_size` objects.
        """
        if name in self._cache:
            return

        self._cache[name] = data
        # The emptiness check stops the loop when `max_cache_size` is zero or negative.
        while self._cache and len(self._cache) > self._max_cache_size:
            oldest_name = next(iter(self._cache))
            self._cache.pop(oldest_name, None)