leideng/QCFuse / srt /mem_cache /base_prefix_cache.py
leideng's picture
download
raw
2.81 kB
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
import torch
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
else:
Req = Any # Placeholder for Req type when not type checking
class MatchResult(NamedTuple):
"""Result of a prefix match operation.
Attributes:
device_indices : Indices of the KV cache on the device matched by common prefix.
last_device_node: The last TreeNode on the device that was matched.
last_host_node : The last TreeNode on the host that was matched.
Note that if HiCache is not enabled,
this **must** be the same as `last_device_node`.
host_hit_length : Length of the KV cache hit on the host, if applicable.
0 if HiCache is not enabled.
"""
device_indices: torch.Tensor
last_device_node: Any
last_host_node: Any
host_hit_length: int = 0
class BasePrefixCache(ABC):
"""Cache can be indexed by either rid or key."""
@abstractmethod
def reset(self):
pass
@abstractmethod
def match_prefix(self, key: Any, **kwargs) -> MatchResult:
pass
@abstractmethod
def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
pass
@abstractmethod
def cache_unfinished_req(self, req: Req, **kwargs):
pass
@abstractmethod
def evict(self, num_tokens: int):
pass
@abstractmethod
def inc_lock_ref(self, node: Any):
pass
@abstractmethod
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
pass
def evictable_size(self):
return 0
def full_evictable_size(self):
return 0
def swa_evictable_size(self):
return 0
def protected_size(self):
return 0
def full_protected_size(self):
return 0
def swa_protected_size(self):
return 0
def total_size(self):
raise NotImplementedError()
def pretty_print(self):
raise NotImplementedError()
def init_load_back(
self,
last_host_node: Any,
host_hit_length: int,
) -> Tuple[torch.Tensor, Any]:
"""
Preparing KV cache loading from host to device.
"""
raise NotImplementedError()
def ready_to_load_host_cache(self) -> Any:
"""
Notify the cache controller to start the KV cache loading
"""
raise NotImplementedError()
def check_hicache_events(self) -> Any:
"""
Check HiCache related activities to update radix tree and synchronize across TP workers if needed
"""
raise NotImplementedError()
def take_events(self):
return []

Xet Storage Details

Size:
2.81 kB
·
Xet hash:
609c01a969d32a68bb7b4097b6809cfb5abbd9f0dae7824861aa27725da00f82

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.