File size: 2,187 Bytes
8a37e0a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | from queue import Queue
from typing import TYPE_CHECKING, Optional, TypeVar
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
T = TypeVar("T")
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
    """
    Provides a bounded in-memory cache for an instance of `ObjectSerializerBase`.

    Saving an object to the cache always writes through to the underlying storage.

    Eviction is by insertion order (FIFO): once more than `max_cache_size` names have been
    queued, the oldest queued name is dropped. Reading an entry does not refresh its position,
    so this is not a true LRU cache.
    """

    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
        """
        :param underlying_storage: The serializer that objects are written through to and
            loaded from on a cache miss.
        :param max_cache_size: The maximum number of names tracked by the cache.
        """
        super().__init__()
        self._underlying_storage = underlying_storage
        # name -> cached object
        self._cache: dict[str, T] = {}
        # Names in the order they were cached; the head is the next eviction candidate.
        self._cache_ids = Queue[str]()
        self._max_cache_size = max_cache_size

    def start(self, invoker: "Invoker") -> None:
        """Stores the invoker and forwards `start` to the underlying storage, if it has one."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: "Invoker") -> None:
        """Stores the invoker and forwards `stop` to the underlying storage, if it has one."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def load(self, name: str) -> T:
        """
        Returns the object stored under `name`.

        The cache is consulted first. On a miss the object is loaded from the underlying
        storage and cached. A cached value of `None` is treated as a miss.
        """
        cache_item = self._get_cache(name)
        if cache_item is not None:
            return cache_item
        obj = self._underlying_storage.load(name)
        self._set_cache(name, obj)
        return obj

    def save(self, obj: T) -> str:
        """Saves `obj` to the underlying storage, caches it, and returns the name it was saved under."""
        name = self._underlying_storage.save(obj)
        self._set_cache(name, obj)
        return name

    def delete(self, name: str) -> None:
        """Deletes `name` from the underlying storage and the cache, then calls `_on_deleted`."""
        self._underlying_storage.delete(name)
        # The name's id stays in `_cache_ids` (a Queue cannot drop arbitrary items);
        # `_set_cache` tolerates such stale ids when it evicts.
        self._cache.pop(name, None)
        self._on_deleted(name)

    def _get_cache(self, name: str) -> Optional[T]:
        """Returns the cached object for `name`, or `None` if it is not cached."""
        return self._cache.get(name)

    def _set_cache(self, name: str, data: T) -> None:
        """
        Caches `data` under `name` unless the name is already cached.

        When the id queue grows past `max_cache_size`, the oldest queued name is evicted.
        """
        if name in self._cache:
            return
        self._cache[name] = data
        self._cache_ids.put(name)
        if self._cache_ids.qsize() > self._max_cache_size:
            # The oldest id may refer to an entry already removed by `delete()`; use a default
            # so evicting a stale id is a no-op instead of raising KeyError.
            self._cache.pop(self._cache_ids.get(), None)
|