| from __future__ import annotations | |
| import logging | |
| from typing import TYPE_CHECKING | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache | |
| from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache | |
| from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache | |
| from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool | |
| from sglang.srt.server_args import get_global_server_args | |
| from sglang.srt.utils import support_triton | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.schedule_batch import Req, ScheduleBatch | |
| logger = logging.getLogger(__name__) | |
# Triton kernel that fills rows of the request-to-token table.
#
# It must be compiled with `triton.jit`: it is launched with the
# `kernel[grid](...)` syntax and uses `tl.*` primitives, neither of which works
# on a plain Python function.
#
# One program handles one request (the program id indexes `req_pool_indices`,
# `pre_lens`, `seq_lens` and `prefix_tensors`). For that request it:
#   1. copies `pre_len` entries from the request's prefix tensor into
#      columns [0, pre_len) of the request's row, and
#   2. copies `seq_len - pre_len` entries from `out_cache_loc`, starting at the
#      sum of `extend_lens` of all earlier requests, into columns
#      [pre_len, seq_len) of the same row.
#
# Arguments:
#   req_to_token_ptr:        base pointer of the table, [max_batch, max_context_len]
#   req_pool_indices:        row index in the table for each request
#   prefix_tensors:          one raw address per request, reinterpreted below as
#                            a pointer to int64 (assumes the prefix tensors hold
#                            int64 values -- TODO confirm against callers)
#   pre_lens / seq_lens:     prefix length and total length per request
#   extend_lens:             number of newly allocated tokens per request
#   out_cache_loc:           newly allocated locations for all requests, concatenated
#   req_to_token_ptr_stride: number of elements between consecutive table rows
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
    # The host passes raw addresses; turn this request's address back into a pointer.
    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))

    # write prefix
    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < pre_len
        value = tl.load(prefix_tensor + offset, mask=mask)
        tl.store(
            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
            value,
            mask=mask,
        )

    # Start of this request's slice in `out_cache_loc`: the sum of the extend
    # lengths of all preceding requests.
    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # write the newly allocated locations right after the prefix
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    """Record prefix and newly allocated token locations for every request.

    For request ``i`` the row ``req_pool_indices[i]`` of the pool receives
    ``prefix_tensors[i]`` in columns ``[0, prefix_len)`` and the next
    ``extend_len`` entries of ``out_cache_loc`` in columns
    ``[prefix_len, seq_len)``.

    The ``*_tensor`` arguments feed the Triton kernel; the ``*_cpu``
    arguments are used by the per-request fallback loop.
    """
    if not support_triton(get_global_server_args().attention_backend):
        # Fallback path: walk the requests one by one and write through the pool.
        consumed = 0
        num_reqs = req_pool_indices_cpu.shape[0]
        for row in range(num_reqs):
            pool_row = req_pool_indices_cpu[row].item()
            n_prefix = prefix_lens_cpu[row].item()
            n_total = seq_lens_cpu[row].item()
            n_new = extend_lens_cpu[row].item()

            req_to_token_pool.write(
                (pool_row, slice(0, n_prefix)),
                prefix_tensors[row],
            )
            req_to_token_pool.write(
                (pool_row, slice(n_prefix, n_total)),
                out_cache_loc[consumed : consumed + n_new],
            )
            consumed += n_new
        return

    # Triton path: hand the kernel the raw address of each prefix tensor.
    addresses = [prefix.data_ptr() for prefix in prefix_tensors]
    prefix_pointers = torch.tensor(addresses, device=req_to_token_pool.device)

    # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
    token_table = req_to_token_pool.req_to_token
    launch_grid = (req_pool_indices_tensor.shape[0],)
    write_req_to_token_pool_triton[launch_grid](
        token_table,
        req_pool_indices_tensor,
        prefix_pointers,
        prefix_lens_tensor,
        seq_lens_tensor,
        extend_lens_tensor,
        out_cache_loc,
        token_table.shape[1],
    )
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return the last prefix token location of each request.

    Dispatches to the pure-torch implementation when the configured attention
    backend is ``"ascend"`` or ``"torch_native"``, and to the Triton
    implementation for every other backend.
    """
    backend = get_global_server_args().attention_backend
    if backend in ("ascend", "torch_native"):
        return get_last_loc_torch(
            req_to_token, req_pool_indices_tensor, prefix_lens_tensor
        )
    return get_last_loc_triton(
        req_to_token, req_pool_indices_tensor, prefix_lens_tensor
    )
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Look up each request's last prefix token location with plain torch ops.

    Entry ``i`` of the result is
    ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``
    when ``prefix_lens_tensor[i] > 0`` and ``-1`` otherwise.
    """
    has_prefix = prefix_lens_tensor > 0
    last_column = prefix_lens_tensor - 1
    # Requests without a prefix index column -1 here; that value is discarded
    # by the `where` below.
    looked_up = req_to_token[req_pool_indices_tensor, last_column]
    missing = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, looked_up, missing)
# Triton kernel computing, for each request, the token location stored at
# column `prefix_len - 1` of its row in `req_to_token`, or -1 when the request
# has no prefix.
#
# It must be compiled with `triton.jit`: it is launched with the
# `kernel[grid](...)` syntax and uses `tl.*` primitives, neither of which works
# on a plain Python function.
#
# Each program processes `BLOCK_SIZE` consecutive requests.
#
# Arguments:
#   req_to_token:            base pointer of the request-to-token table
#   req_pool_indices_tensor: table row index for each request
#   prefix_lens_tensor:      prefix length for each request
#   result:                  output buffer, one entry per request
#   num_tokens:              number of entries in the per-request inputs/outputs
#   req_to_token_stride:     number of elements between consecutive table rows
#   BLOCK_SIZE:              entries handled by one program (compile-time constant)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    # Lanes past the end of the inputs read neutral values and store nothing.
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    # Only requests with a non-empty prefix read the table; the rest get -1.
    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Launch ``get_last_loc_kernel`` and return its output tensor.

    The output has the same shape, dtype and device as ``prefix_lens_tensor``.
    """
    block = 256
    count = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)

    # One program per `block` requests.
    launch_grid = (triton.cdiv(count, block),)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        count,
        req_to_token.stride(0),
        block,
    )
    return out
def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    """Allocate ``num_tokens`` KV-cache slots, evicting from the cache first.

    Returns:
        The allocated locations, or ``(locations, state)`` when
        ``backup_state`` is set, where ``state`` is the allocator state
        captured right before the allocation.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    saved_state = kv_allocator.backup_state() if backup_state else None

    out_cache_loc = kv_allocator.alloc(num_tokens)
    if out_cache_loc is not None:
        if backup_state:
            return out_cache_loc, saved_state
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Out of memory. Try to lower your batch size.\n"
        f"Try to allocate {num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    """Evict from ``tree_cache`` when its allocator has fewer than ``num_tokens`` free.

    Does nothing when ``tree_cache`` is ``None`` or is a chunk cache
    (``ChunkCache`` / ``SWAChunkCache``).
    """
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Standard allocator: a single pool, evict `num_tokens` when it is short.
    if not hasattr(kv_allocator, "full_available_size"):
        if kv_allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)
        return

    # Hybrid allocator: full and SWA pools are tracked separately; evict only
    # the shortfall of each pool.
    full_free = kv_allocator.full_available_size()
    swa_free = kv_allocator.swa_available_size()
    if full_free < num_tokens or swa_free < num_tokens:
        full_shortfall = max(0, num_tokens - full_free)
        swa_shortfall = max(0, num_tokens - swa_free)
        tree_cache.evict(full_shortfall, swa_shortfall)
def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV-cache slots for an extend (prefill) batch.

    Returns:
        The allocated locations, or ``(locations, state)`` when
        ``backup_state`` is set, where ``state`` is the allocator state
        captured right before the allocation.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Over estimate the number of tokens: assume each request needs a new page.
    eviction_budget = extend_num_tokens + len(seq_lens_cpu) * kv_allocator.page_size
    evict_from_tree_cache(tree_cache, eviction_budget)

    saved_state = kv_allocator.backup_state() if backup_state else None

    out_cache_loc = kv_allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )
    if out_cache_loc is not None:
        if backup_state:
            return out_cache_loc, saved_state
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Prefill out of memory. Try to lower your batch size.\n"
        f"Try to allocate {extend_num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
    tree_cache: BasePrefixCache | None,
) -> list[int]:
    """Reserve ``num_reqs`` request slots in ``req_to_token_pool``.

    For a ``HybridReqToTokenPool`` whose mamba pool has fewer than
    ``num_reqs`` free entries, the shortfall is first evicted from
    ``tree_cache`` when that cache is a ``MambaRadixCache``.

    Raises:
        RuntimeError: if the pool returns ``None``.
    """
    if not isinstance(req_to_token_pool, HybridReqToTokenPool):
        req_pool_indices = req_to_token_pool.alloc(num_reqs)
    else:
        mamba_free = req_to_token_pool.mamba_pool.available_size()
        # isinstance() is already False for None, so no separate None check.
        if mamba_free < num_reqs and isinstance(tree_cache, MambaRadixCache):
            tree_cache.evict_mamba(max(0, num_reqs - mamba_free))
        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)

    if req_pool_indices is not None:
        return req_pool_indices

    raise RuntimeError(
        "alloc_req_slots runs out of memory. "
        "Please set a smaller number for `--max-running-requests`. "
        f"{req_to_token_pool.available_size()=}, "
        f"{num_reqs=}, "
    )
def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Reserve request slots and KV-cache slots for an extend batch, then record
    the resulting token locations in ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: allocated cache locations
        req_pool_indices_device: request pool indices as a tensor on ``batch.device``
        req_pool_indices: request pool indices as returned by ``alloc_req_slots``
    """
    cache = batch.tree_cache
    device = batch.device

    def _cpu_and_device(values):
        # Build an int64 CPU tensor plus a non-blocking copy on the batch device.
        on_cpu = torch.tensor(values, dtype=torch.int64)
        return on_cpu, on_cpu.to(device, non_blocking=True)

    # free out-of-window swa tokens
    if isinstance(cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            cache.evict_swa(req, pre_len, batch.model_config.attention_chunk_size)

    num_reqs = len(batch.reqs)
    prefix_tensors = [req.prefix_indices for req in batch.reqs]

    # Length tensors used by the allocator and by the index write below.
    prefix_lens_cpu, prefix_lens_device = _cpu_and_device(batch.prefix_lens)
    extend_lens_cpu, extend_lens_device = _cpu_and_device(batch.extend_lens)

    # Request slots.
    req_pool_indices = alloc_req_slots(
        batch.req_to_token_pool, num_reqs, batch.reqs, cache
    )
    req_pool_indices_cpu, req_pool_indices_device = _cpu_and_device(req_pool_indices)

    # KV-cache slots (the allocation helpers raise on failure).
    if cache.page_size != 1:
        # Paged allocation needs the last prefix location of each request,
        # with -1 standing in for requests that have no prefix.
        tails = []
        for prefix in prefix_tensors:
            if len(prefix) > 0:
                tails.append(prefix[-1:])
            else:
                tails.append(torch.tensor([-1], device=device))
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=torch.cat(tails),
            extend_num_tokens=batch.extend_num_tokens,
        )
    else:
        out_cache_loc = alloc_token_slots(cache, batch.extend_num_tokens)

    # Record prefix + new locations in the request-to-token pool.
    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices
def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV-cache slots for a decode batch.

    ``token_per_req`` is only used to report the requested token count in the
    out-of-memory message.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator
    num_reqs = len(seq_lens)

    # Over estimate the number of tokens: assume each request needs a new page.
    evict_from_tree_cache(tree_cache, num_reqs * kv_allocator.page_size)

    out_cache_loc = kv_allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is not None:
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Decode out of memory. Try to lower your batch size.\n"
        f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """
    Reserve KV-cache slots for a decode batch and record them in
    ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: allocated cache locations
    """
    cache = batch.tree_cache

    if isinstance(cache, SWAChunkCache):
        for req in batch.reqs:
            cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    num_reqs = batch.seq_lens.shape[0]

    if cache.page_size != 1:
        # Paged allocation: needs each request's current last token location.
        pool_table = batch.req_to_token_pool.req_to_token
        last_loc = pool_table[batch.req_pool_indices, batch.seq_lens - 1]
        seq_lens_next = batch.seq_lens + token_per_req
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=cache,
            seq_lens=seq_lens_next,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )
    else:
        # Non-paged allocation
        out_cache_loc = alloc_token_slots(cache, num_reqs * token_per_req)

    # Column to write for each request: the current sequence length, shifted by
    # the encoder length for encoder-decoder models.
    if batch.model_config.is_encoder_decoder:
        write_cols = batch.encoder_lens + batch.seq_lens
    else:
        write_cols = batch.seq_lens.clone()

    batch.req_to_token_pool.write(
        (batch.req_pool_indices, write_cols), out_cache_loc.to(torch.int32)
    )

    return out_cache_loc
def available_and_evictable_str(tree_cache) -> str:
    """Describe how many tokens are free or evictable, for error messages.

    Reports full and SWA pools separately when the cache's allocator is a
    ``SWATokenToKVPoolAllocator``, and a single combined figure otherwise.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # NOTE: the local names below appear verbatim in the output through the
    # `{name=}` f-string fields, so they must not be renamed.
    if not isinstance(kv_allocator, SWATokenToKVPoolAllocator):
        available_size = kv_allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    full_available_size = kv_allocator.full_available_size()
    swa_available_size = kv_allocator.swa_available_size()
    full_evictable_size = tree_cache.full_evictable_size()
    swa_evictable_size = tree_cache.swa_evictable_size()
    report_lines = [
        f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n",
        f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n",
        f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n",
        f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n",
    ]
    return "".join(report_lines)
Xet Storage Details
- Size:
- 16.3 kB
- Xet hash:
- e958e9f900507728640ab18fb87be5fc22ff5a483e1a391953c586e517d49a3f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.