|
|
import sys |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import ( |
|
|
Optional, |
|
|
Sequence, |
|
|
Tuple, |
|
|
) |
|
|
from collections import OrderedDict |
|
|
|
|
|
import diskcache |
|
|
|
|
|
import llama_cpp.llama |
|
|
|
|
|
from .llama_types import * |
|
|
|
|
|
|
|
|
class BaseLlamaCache(ABC):
    """Abstract interface shared by the llama.cpp state caches.

    Keys are token sequences and values are ``llama_cpp.llama.LlamaState``
    objects. Concrete caches must implement ``cache_size``, ``__getitem__``,
    ``__contains__`` and ``__setitem__``.
    """

    def __init__(self, capacity_bytes: int = 2 << 30):
        """Store the capacity limit on the instance.

        Args:
            capacity_bytes: Limit kept as ``self.capacity_bytes``; the default
                ``2 << 30`` equals 2 GiB.
        """
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Return the current size of the cache (subclasses must override)."""
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Hook for finding the stored key that best prefix-matches ``key``.

        This base version performs no lookup and always returns ``None``.
        """
        return None

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state cached for ``key`` (subclasses must override)."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether ``key`` has a cached state (subclasses must override)."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
        self,
        key: Sequence[int],
        value: "llama_cpp.llama.LlamaState",
    ) -> None:
        """Store ``value`` under ``key`` (subclasses must override)."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM.

    States are held in an ``OrderedDict`` keyed by token tuples. A successful
    lookup or a store moves the entry to the end of the dict, and eviction
    removes entries from the front until the summed ``llama_state_size`` of
    the stored states no longer exceeds ``capacity_bytes``.

    Lookups are prefix based: a key matches the stored key that shares the
    longest non-empty prefix with it, as measured by
    ``llama_cpp.llama.Llama.longest_token_prefix``.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        """Create an empty cache.

        Args:
            capacity_bytes: Upper bound for ``cache_size`` enforced on every
                store; the default ``2 << 30`` equals 2 GiB.
        """
        # The base initializer already records ``capacity_bytes``.
        super().__init__(capacity_bytes)
        self.cache_state: OrderedDict[
            Tuple[int, ...], "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
    def cache_size(self) -> int:
        """Return the sum of ``llama_state_size`` over all stored states."""
        return sum(state.llama_state_size for state in self.cache_state.values())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest prefix with ``key``.

        Args:
            key: Token tuple to match against the stored keys.

        Returns:
            The first stored key with the greatest prefix length, or ``None``
            when no stored key has a prefix length greater than zero.
        """
        best_len = 0
        best_key = None
        for candidate in self.cache_state:
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(candidate, key)
            # Strict comparison keeps the earliest key on ties and rejects
            # candidates with no common prefix at all.
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = candidate
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state stored under the best prefix match for ``key``.

        The matched entry is moved to the end of the dict, i.e. furthest from
        eviction.

        Raises:
            KeyError: If no stored key shares a non-empty prefix with ``key``.
        """
        matched = self._find_longest_prefix_key(tuple(key))
        if matched is None:
            raise KeyError("Key not found")
        value = self.cache_state[matched]
        self.cache_state.move_to_end(matched)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return ``True`` if some stored key shares a non-empty prefix with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` under ``key`` and evict front entries while over capacity.

        An existing entry for the same key is replaced and moved to the end.
        Eviction may remove the entry just stored if it alone exceeds
        ``capacity_bytes``.
        """
        key = tuple(key)
        # Drop any previous entry so the new one lands at the end of the dict.
        self.cache_state.pop(key, None)
        self.cache_state[key] = value
        # Track the total incrementally instead of re-summing every state for
        # each evicted entry.
        total = self.cache_size
        while total > self.capacity_bytes and self.cache_state:
            _, evicted = self.cache_state.popitem(last=False)
            total -= evicted.llama_state_size
|
|
|
|
|
|
|
|
|
|
|
# ``LlamaCache`` is a second module-level name bound to ``LlamaRAMCache``;
# both names refer to the same class object.
LlamaCache = LlamaRAMCache
|
|
|
|
|
|
|
|
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk.

    States are stored in a ``diskcache.Cache`` keyed by token tuples. Lookups
    are prefix based: a key matches the stored key that shares the longest
    non-empty prefix with it, as measured by
    ``llama_cpp.llama.Llama.longest_token_prefix``.
    """

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        """Open (or create) the disk cache.

        Args:
            cache_dir: Directory handed to ``diskcache.Cache``.
            capacity_bytes: Upper bound for ``cache_size`` enforced on every
                store; the default ``2 << 30`` equals 2 GiB.
        """
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self) -> int:
        """Return the size reported by ``diskcache.Cache.volume()`` as an int."""
        return int(self.cache.volume())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest prefix with ``key``.

        Args:
            key: Token tuple to match against the stored keys.

        Returns:
            The first key yielded by ``self.cache.iterkeys()`` with the
            greatest prefix length, or ``None`` when no stored key has a
            prefix length greater than zero.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for candidate in self.cache.iterkeys():
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(candidate, key)
            # Strict comparison keeps the earliest key on ties and rejects
            # candidates with no common prefix at all.
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = candidate
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state stored under the best prefix match for ``key``.

        The entry stays in the cache after being read, so repeated lookups
        keep succeeding and agree with ``__contains__``.

        Raises:
            KeyError: If no stored key shares a non-empty prefix with ``key``,
                or the matched entry is gone by the time it is read.
        """
        matched = self._find_longest_prefix_key(tuple(key))
        if matched is None:
            raise KeyError("Key not found")
        # Plain read instead of ``pop``: a lookup must not delete the entry.
        value: "llama_cpp.llama.LlamaState" = self.cache[matched]
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return ``True`` if some stored key shares a non-empty prefix with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` under ``key`` and trim the cache while over capacity.

        Any existing entry for the same key is deleted before the new value is
        written. Afterwards, the first key yielded by iterating the cache is
        removed repeatedly until ``cache_size`` no longer exceeds
        ``capacity_bytes`` or the cache is empty.
        """
        # NOTE(review): the progress messages below are written to stderr
        # unconditionally on every store; consider routing them through
        # ``logging`` at debug level.
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        key = tuple(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]
        self.cache[key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
        print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
|
|
|