File size: 2,187 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from queue import Queue
from typing import TYPE_CHECKING, Optional, TypeVar

from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase

# Type variable for the objects being (de)serialized; the cache below is generic over it.
T = TypeVar("T")

if TYPE_CHECKING:
    from invokeai.app.services.invoker import Invoker


class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
    """
    Provides a LRU cache for an instance of `ObjectSerializerBase`.
    Saving an object to the cache always writes through to the underlying storage.

    Cached objects live in a plain ``dict``, whose iteration order is insertion
    order. Every read or write of an entry re-inserts it at the end, so the first
    key is always the least recently used one and is the first to be evicted once
    the cache grows beyond ``max_cache_size``.

    Args:
        underlying_storage: The serializer that actually persists objects. Saves and
            deletes are always forwarded to it; loads reach it only on a cache miss.
        max_cache_size: Maximum number of objects kept in memory. A value of 0 or
            less means nothing stays cached.
    """

    def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
        super().__init__()
        self._underlying_storage = underlying_storage
        # Recency order is the dict's insertion order:
        # first key = least recently used, last key = most recently used.
        self._cache: dict[str, T] = {}
        self._max_cache_size = max_cache_size

    def start(self, invoker: "Invoker") -> None:
        """Remember the invoker and forward `start` to the underlying storage, if it has one."""
        self._invoker = invoker
        start_op = getattr(self._underlying_storage, "start", None)
        if callable(start_op):
            start_op(invoker)

    def stop(self, invoker: "Invoker") -> None:
        """Remember the invoker and forward `stop` to the underlying storage, if it has one."""
        self._invoker = invoker
        stop_op = getattr(self._underlying_storage, "stop", None)
        if callable(stop_op):
            stop_op(invoker)

    def load(self, name: str) -> T:
        """Return the object stored under `name`.

        A cache hit is served from memory and marked most recently used. On a miss,
        the object is loaded from the underlying storage and added to the cache.
        """
        cache_item = self._get_cache(name)
        if cache_item is not None:
            return cache_item

        obj = self._underlying_storage.load(name)
        self._set_cache(name, obj)
        return obj

    def save(self, obj: T) -> str:
        """Write `obj` through to the underlying storage, cache it, and return its name."""
        name = self._underlying_storage.save(obj)
        self._set_cache(name, obj)
        return name

    def delete(self, name: str) -> None:
        """Delete `name` from the underlying storage and drop it from the cache."""
        self._underlying_storage.delete(name)
        # Recency order lives in the dict itself, so removing the key leaves no
        # stale reference behind that a later eviction could trip over.
        self._cache.pop(name, None)
        # Inherited hook (not defined in this class).
        self._on_deleted(name)

    def _get_cache(self, name: str) -> Optional[T]:
        """Return the cached object for `name` and mark it most recently used; None on a miss."""
        if name not in self._cache:
            return None
        # Re-insert so the entry moves to the most-recently-used end of the order.
        data = self._cache.pop(name)
        self._cache[name] = data
        return data

    def _set_cache(self, name: str, data: T) -> None:
        """Store `data` under `name` as the most recently used entry, then evict over capacity."""
        # Pop first so an existing key is both refreshed (new value, e.g. a re-save
        # under the same name) and moved to the most-recently-used end.
        self._cache.pop(name, None)
        self._cache[name] = data
        # Evict least-recently-used entries (front of the dict) until within capacity.
        # The emptiness guard keeps this terminating for max_cache_size < 0.
        while self._cache and len(self._cache) > self._max_cache_size:
            oldest = next(iter(self._cache))
            del self._cache[oldest]