| | import sys |
| | from abc import ABC, abstractmethod |
| | from typing import ( |
| | Optional, |
| | Sequence, |
| | Tuple, |
| | ) |
| | from collections import OrderedDict |
| |
|
| | import diskcache |
| |
|
| | import llama_cpp.llama |
| |
|
| | from .llama_types import * |
| |
|
| |
|
class BaseLlamaCache(ABC):
    """Abstract interface for a llama.cpp model state cache.

    Keys are token sequences (``Sequence[int]``); values are
    ``llama_cpp.llama.LlamaState`` snapshots. Subclasses choose where the
    states live and implement lookup, membership and insertion.

    Args:
        capacity_bytes: Byte budget for the cache, stored on the instance as
            ``capacity_bytes`` (default ``2 << 30``, i.e. 2 GiB). This base
            class only records the value; enforcing it is up to subclasses.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Current size of the cache in bytes (implemented by subclasses)."""
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest prefix with ``key``.

        The base class has no storage, so this default always yields ``None``.
        """
        return None

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Look up the state associated with ``key``."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether a state can be found for ``key``."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
    ) -> None:
        """Store ``value`` under ``key``."""
        raise NotImplementedError
| |
|
| |
|
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM.

    States are held in an ``OrderedDict`` ordered from least- to
    most-recently used: a successful lookup moves the matched entry to the
    end, and eviction pops from the front once the summed
    ``llama_state_size`` of all entries exceeds ``capacity_bytes``.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        # BaseLlamaCache.__init__ already stores ``capacity_bytes``.
        super().__init__(capacity_bytes)
        self.cache_state: OrderedDict[
            Tuple[int, ...], "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
    def cache_size(self):
        """Sum of ``llama_state_size`` over all cached states."""
        return sum(state.llama_state_size for state in self.cache_state.values())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the cached key with the longest token prefix shared with ``key``.

        Ties go to the key met first in iteration order (the least recently
        used one). Returns ``None`` when the cache is empty or every cached
        key has a prefix length of 0 with ``key``.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for cached_key in self.cache_state:
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(cached_key, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = cached_key
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state whose key shares the longest prefix with ``key``.

        The matched entry is marked as most recently used.

        Raises:
            KeyError: if no cached key shares a non-empty prefix with ``key``.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value = self.cache_state[_key]
        self.cache_state.move_to_end(_key)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return True if some cached key shares a non-empty prefix with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` under ``key`` as the most recently used entry.

        Least recently used entries are then evicted until the total size is
        within ``capacity_bytes``. If ``value`` alone exceeds the budget it is
        evicted too, leaving the cache empty.
        """
        key = tuple(key)
        if key in self.cache_state:
            # Delete before re-inserting so the key lands at the MRU end.
            del self.cache_state[key]
        self.cache_state[key] = value
        # Compute the total once and update it per eviction, instead of
        # re-summing every cached state on each loop iteration (O(n^2)).
        size = self.cache_size
        while size > self.capacity_bytes and self.cache_state:
            _, evicted = self.cache_state.popitem(last=False)
            size -= evicted.llama_state_size


# ``LlamaCache`` is an alias for the in-memory cache implementation above.
LlamaCache = LlamaRAMCache
| |
|
| |
|
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk.

    States are persisted in a ``diskcache.Cache`` rooted at ``cache_dir``.
    After each insertion, entries are removed in the order
    ``iter(self.cache)`` yields them until ``cache.volume()`` is within
    ``capacity_bytes``. NOTE(review): that order is presumably oldest
    insertion first; confirm against the diskcache documentation.
    """

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self):
        """Size of the on-disk cache in bytes, as reported by ``Cache.volume()``."""
        return int(self.cache.volume())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key with the longest token prefix shared with ``key``.

        Ties go to the key met first in ``iterkeys()`` order. Returns ``None``
        when the cache is empty or every stored key has a prefix length of 0.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for cached_key in self.cache.iterkeys():
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(cached_key, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = cached_key
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state whose key shares the longest prefix with ``key``.

        NOTE: unlike the RAM cache, the matched entry is *removed* from disk
        (``Cache.pop``). The caller must store the state again to keep it
        cached.

        Raises:
            KeyError: if no stored key shares a non-empty prefix with ``key``.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return True if some stored key shares a non-empty prefix with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Persist ``value`` under ``key``, then evict until within capacity.

        Entries are removed one at a time, taking the first key yielded by
        ``iter(self.cache)``, while ``cache_size`` exceeds ``capacity_bytes``
        and the cache is non-empty.
        """
        key = tuple(key)
        if key in self.cache:
            # Delete before re-inserting. Presumably this moves the key to the
            # back of the iteration order used for eviction below; confirm
            # against diskcache.
            del self.cache[key]
        self.cache[key] = value
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
| |
|