| import itertools |
| from typing import Optional |
|
|
class TaggedCache:
    """Key/value cache whose entries are grouped into per-tag buckets.

    Every stored value is a tuple whose first element is its tag. Each tag
    gets its own bucket: a ``cachetools.LRUCache`` when that package can be
    imported, otherwise a plain ``dict`` with no size limit. A key lives in at
    most one bucket at a time, so lookups only need the key, not the tag.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        """Create an empty cache.

        :param tag_settings: optional mapping of tag -> maximum bucket size,
            consulted when the bucket for a tag is first created.
        """
        # Keep the caller's dict even when it is empty so that entries added
        # to it later are still honoured (``tag_settings or {}`` would swap an
        # empty dict for an unrelated new one).
        self._tag_settings = tag_settings if tag_settings is not None else {}
        self._data = {}

    @staticmethod
    def _default_size(tag) -> int:
        """Return the built-in bucket size for *tag* when no setting exists."""
        try:
            is_ckpt = 'ckpt' in tag
        except TypeError:
            # Tag does not support membership tests (e.g. an int).
            is_ckpt = False
        if is_ckpt:
            return 5
        if tag in ('latent', 'image'):
            return 100
        return 20

    def _new_bucket(self, tag):
        """Build the container that will hold every entry tagged *tag*."""
        try:
            from cachetools import LRUCache
        except ImportError:
            # cachetools is optional; fall back to an unbounded dict.
            return {}
        return LRUCache(maxsize=self._tag_settings.get(tag, self._default_size(tag)))

    def _find_bucket(self, key):
        """Return the bucket currently holding *key*, or ``None``."""
        for bucket in self._data.values():
            if key in bucket:
                return bucket
        return None

    def __getitem__(self, key):
        """Return the stored tuple for *key*.

        :raises KeyError: if no bucket contains *key*.
        """
        bucket = self._find_bucket(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        return bucket[key]

    def __setitem__(self, key, value: tuple):
        """Store *value* under *key* in the bucket named by ``value[0]``.

        :param key: cache key
        :param value: tuple whose first element is the tag; the whole tuple
            is what gets stored and later returned.
        """
        # Validate the value and resolve the target bucket *before* touching
        # the existing entry, so a malformed value cannot destroy it.
        tag = value[0]
        if tag not in self._data:
            self._data[tag] = self._new_bucket(tag)

        # A key may only live in one bucket: drop any previous entry first.
        previous = self._find_bucket(key)
        if previous is not None:
            previous.pop(key, None)

        self._data[tag][key] = value

    def __delitem__(self, key):
        """Remove *key* from whichever bucket holds it.

        :raises KeyError: if no bucket contains *key*.
        """
        bucket = self._find_bucket(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        del bucket[key]

    def __contains__(self, key):
        """Return ``True`` if any bucket contains *key*."""
        return self._find_bucket(key) is not None

    def items(self):
        """Yield ``(key, value)`` pairs from every bucket, one bucket after another."""
        # Snapshot the bucket list so adding a new tag while iterating does
        # not invalidate the outer iteration.
        for bucket in list(self._data.values()):
            yield from bucket.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        bucket = self._find_bucket(key)
        if bucket is None:
            return default
        return bucket[key]

    def clear(self):
        """Drop every bucket and all entries in them."""
        self._data = {}
|
|
# Mapping of tag -> bucket size handed to the shared cache at construction.
cache_settings: dict = {}

# Module-wide cache instance used by update_cache() and remove_cache().
cache: TaggedCache = TaggedCache(cache_settings)

# Per-key counter maintained by update_cache() and cleared by remove_cache().
cache_count: dict = {}
|
|
def update_cache(k, tag, v):
    """Store ``(tag, v)`` in the shared cache under ``k`` and bump its counter.

    ``cache_count[k]`` is set to 0 when no count is recorded for ``k`` yet,
    and is incremented by one on every later call for the same key.
    """
    cache[k] = (tag, v)

    previous = cache_count.get(k)
    cache_count[k] = 0 if previous is None else previous + 1
def remove_cache(key):
    """Remove one entry, or everything, from the shared cache.

    Passing ``'*'`` empties both the cache and the counters. Any other key is
    deleted from the cache and its counter is dropped; a key that is not in
    the cache only produces a printed notice.
    """
    # Wildcard: wipe the cache and the bookkeeping together.
    if key == '*':
        cache.clear()
        cache_count.clear()
        return

    if key not in cache:
        print(f"invalid {key}")
        return

    del cache[key]
    cache_count.pop(key, None)
|
|