| import logging | |
| import threading | |
| import time | |
| import torch | |
| from sglang.srt.managers.cache_controller import HiCacheController | |
| from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache | |
| from sglang.srt.mem_cache.memory_pool import ( | |
| MHATokenToKVPool, | |
| MLATokenToKVPool, | |
| ReqToTokenPool, | |
| ) | |
| from sglang.srt.mem_cache.memory_pool_host import ( | |
| MHATokenToKVPoolHost, | |
| MLATokenToKVPoolHost, | |
| ) | |
| from sglang.srt.server_args import ServerArgs | |
| logger = logging.getLogger(__name__) | |
class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations.

    A request's KV cache moves through two stages, each tracked in its own
    dict keyed by an acknowledgement id:

    1. device -> host copy, started by ``offload_kv_cache`` and tracked in
       ``ongoing_offload``;
    2. host -> storage backup, started by ``_trigger_backup`` once the copy
       has finished and tracked in ``ongoing_backup``.

    ``check_offload_progress`` polls the cache controller's acknowledgement
    queues and advances both stages.
    """

    def __init__(
        self,
        req_to_token_pool: "ReqToTokenPool",
        token_to_kv_pool_allocator: "BaseTokenToKVPoolAllocator",
        tp_group: "torch.distributed.ProcessGroup",
        tree_cache: "BasePrefixCache",
        server_args: "ServerArgs",
    ) -> None:
        """Build the host memory pool and the cache controller.

        Args:
            req_to_token_pool: Pool whose ``req_to_token`` table is indexed by
                ``req.req_pool_idx`` to obtain a request's token indices.
            token_to_kv_pool_allocator: Allocator that owns the device KV cache.
            tp_group: Process group used to agree on acknowledgement counts.
            tree_cache: Prefix cache that receives each request once its
                device-to-host copy has finished.
            server_args: Source of the page size and the ``hicache_*`` options.

        Raises:
            ValueError: If the device KV cache is neither an
                ``MHATokenToKVPool`` nor an ``MLATokenToKVPool``.
        """
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        # Monotonic counter; each offload attempt takes the next value as its id.
        self.request_counter = 0
        self.tree_cache = tree_cache

        # Both host pool flavours take the same arguments, so pick the class
        # first and construct it once.
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            host_pool_cls = MHATokenToKVPoolHost
        elif isinstance(kv_cache, MLATokenToKVPool):
            host_pool_cls = MLATokenToKVPoolHost
        else:
            raise ValueError("Unsupported KV cache type for decode offload")
        self.decode_host_mem_pool = host_pool_cls(
            kv_cache,
            server_args.hicache_ratio,
            server_args.hicache_size,
            self.page_size,
            server_args.hicache_mem_layout,
        )

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)

        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        # ack_id -> (req, host_indices, tokens, start_time)
        self.ongoing_offload = {}
        # ack_id -> (req_id, host_indices, start_time)
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage.

        Only whole pages are offloaded: the token sequence is truncated to the
        largest multiple of ``page_size`` before the device-to-host copy is
        requested from the cache controller.

        Args:
            req: Request providing ``rid``, ``req_pool_idx``,
                ``origin_input_ids`` and ``output_ids``.

        Returns:
            True if the copy was started and the request is now tracked in
            ``ongoing_offload``; False if there was nothing to offload or the
            controller could not provide host indices.
        """
        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            # Lazy %-formatting: the tensor is only rendered when DEBUG is on.
            logger.debug(
                "Request %s has invalid token_indices: %s", req.rid, token_indices
            )
            return False

        # NOTE(review): this counts every output token, including the most
        # recent one; confirm its KV entry has been written before offloading.
        tokens = req.origin_input_ids + req.output_ids
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error("Not enough host memory for request %s", req.rid)
            return False

        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage.

        The local lengths of both acknowledgement queues are MIN-reduced over
        the TP group (when it has more than one rank), so every rank handles
        the same number of acknowledgements in this call.
        """
        cc = self.cache_controller
        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host.

        Pops ``finish_count`` entries from the controller's write
        acknowledgement queue, waits on each entry's event, then hands every
        acknowledged request to the tree cache and starts its storage backup.
        """
        for _ in range(finish_count):
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage.

        Takes ``finish_count`` operations from the controller's backup
        acknowledgement queue and frees the host indices of each one.
        """
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)
            logger.debug(
                "Finished backup request %s, free host memory, len:%d, "
                "cost time:%.2f seconds.",
                req_id,
                len(host_indices),
                time.time() - start_time,
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller.

        The id returned by ``write_storage`` becomes the key under which the
        backup is tracked in ``ongoing_backup``.
        """
        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        """Return one chained hash per page of ``tokens``.

        Each page is hashed together with the previous page's hash (an empty
        string for the first page), so a hash identifies the whole prefix up
        to and including its page.
        """
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes
Xet Storage Details
- Size:
- 7.05 kB
- Xet hash:
- 7d18960c36ebe171ec8e76cbfbed775b9b2f9fb2c3fccdf5f05ad368b694a478
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.