leideng's picture
download
raw
16.3 kB
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
logger = logging.getLogger(__name__)
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    # Each program handles one request (program id == position in the batch).
    # It fills that request's row of req_to_token in two steps:
    #   1. columns [0, pre_len)       <- values read through the prefix pointer
    #   2. columns [pre_len, seq_len) <- this request's slice of out_cache_loc
    # `prefix_tensors` holds one raw address per request; it is reinterpreted
    # below as a pointer to int64 elements.
    BLOCK_SIZE: tl.constexpr = 512
    req_id = tl.program_id(0)

    pool_row = tl.load(req_pool_indices + req_id)
    prefix_len = tl.load(pre_lens + req_id)
    total_len = tl.load(seq_lens + req_id)
    prefix_src = tl.load(prefix_tensors + req_id).to(tl.pointer_type(tl.int64))

    # Start of this request's row inside the req_to_token table.
    row_ptr = req_to_token_ptr + pool_row * req_to_token_ptr_stride
    lanes = tl.arange(0, BLOCK_SIZE)

    # write prefix
    prefix_blocks = tl.cdiv(prefix_len, BLOCK_SIZE)
    for blk in range(prefix_blocks):
        cols = lanes + blk * BLOCK_SIZE
        valid = cols < prefix_len
        prefix_vals = tl.load(prefix_src + cols, mask=valid)
        tl.store(row_ptr + cols, prefix_vals, mask=valid)

    # NOTE: This can be slow for large bs
    # Offset of this request's slice in out_cache_loc: the sum of the extend
    # lengths of all requests that precede it in the batch.
    src_start = tl.cast(0, tl.int64)
    for prev in range(req_id):
        src_start += tl.load(extend_lens + prev)

    # write the newly allocated locations right after the prefix
    new_len = total_len - prefix_len
    extend_blocks = tl.cdiv(new_len, BLOCK_SIZE)
    for blk in range(extend_blocks):
        cols = lanes + blk * BLOCK_SIZE
        valid = cols < new_len
        new_vals = tl.load(out_cache_loc + src_start + cols, mask=valid)
        tl.store(row_ptr + cols + prefix_len, new_vals, mask=valid)
def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    """Record prefix and newly allocated cache locations in ``req_to_token_pool``.

    For the i-th request, positions ``[0, prefix_len)`` of its row receive
    ``prefix_tensors[i]`` and positions ``[prefix_len, seq_len)`` receive the
    next ``extend_len`` entries of ``out_cache_loc`` (entries are consumed in
    request order).

    When ``support_triton`` accepts the configured attention backend, the
    ``*_tensor`` arguments are handed to the Triton kernel; otherwise the
    ``*_cpu`` arguments drive a per-request Python loop.
    """
    if support_triton(get_global_server_args().attention_backend):
        # The kernel reads each prefix through its raw address, so pass the
        # data pointers as a tensor on the pool's device.
        prefix_pointers = torch.tensor(
            [t.data_ptr() for t in prefix_tensors],
            device=req_to_token_pool.device,
        )
        req_to_token = req_to_token_pool.req_to_token
        launch_grid = (req_pool_indices_tensor.shape[0],)
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
        write_req_to_token_pool_triton[launch_grid](
            req_to_token,
            req_pool_indices_tensor,
            prefix_pointers,
            prefix_lens_tensor,
            seq_lens_tensor,
            extend_lens_tensor,
            out_cache_loc,
            # Second dimension of the table is passed as the row stride.
            req_to_token.shape[1],
        )
        return

    # Fallback: write row slices one request at a time.
    pool_rows = req_pool_indices_cpu.tolist()
    prefix_lengths = prefix_lens_cpu.tolist()
    total_lengths = seq_lens_cpu.tolist()
    extend_lengths = extend_lens_cpu.tolist()

    cursor = 0  # read position inside out_cache_loc
    for i, row in enumerate(pool_rows):
        prefix_end = prefix_lengths[i]
        seq_end = total_lengths[i]
        n_new = extend_lengths[i]

        req_to_token_pool.write((row, slice(0, prefix_end)), prefix_tensors[i])
        req_to_token_pool.write(
            (row, slice(prefix_end, seq_end)),
            out_cache_loc[cursor : cursor + n_new],
        )
        cursor += n_new
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Dispatch to the torch or Triton ``get_last_loc`` implementation.

    The pure-torch version is used when the configured attention backend is
    ``"ascend"`` or ``"torch_native"``; every other backend uses the Triton
    version. The arguments are forwarded unchanged.
    """
    backend = get_global_server_args().attention_backend
    if backend in ("ascend", "torch_native"):
        return get_last_loc_torch(
            req_to_token, req_pool_indices_tensor, prefix_lens_tensor
        )
    return get_last_loc_triton(
        req_to_token, req_pool_indices_tensor, prefix_lens_tensor
    )
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return, per request, the table entry at its last prefix position.

    For each request the result is
    ``req_to_token[req_pool_index, prefix_len - 1]`` when ``prefix_len > 0``
    and ``-1`` otherwise.
    """
    has_prefix = prefix_lens_tensor > 0
    # For prefix_len == 0 this indexes column -1; that value is discarded by
    # the `where` below.
    last_columns = prefix_lens_tensor - 1
    last_entries = req_to_token[req_pool_indices_tensor, last_columns]
    no_prefix_fill = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, last_entries, no_prefix_fill)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program processes BLOCK_SIZE consecutive entries. For entry i it
    # stores req_to_token[req_pool_indices[i], prefix_lens[i] - 1] into
    # result[i], or -1 when prefix_lens[i] is not positive.
    block_id = tl.program_id(0)
    lanes = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < num_tokens

    # Out-of-range lanes read 0 for both inputs, so they count as "no prefix"
    # and never touch req_to_token.
    pre_lens = tl.load(prefix_lens_tensor + lanes, mask=in_bounds, other=0)
    pool_rows = tl.load(req_pool_indices_tensor + lanes, mask=in_bounds, other=0)

    has_prefix = pre_lens > 0
    flat_index = pool_rows * req_to_token_stride + (pre_lens - 1)
    last_tokens = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)
    tl.store(result + lanes, last_tokens, mask=in_bounds)
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Launch ``get_last_loc_kernel`` and return its output tensor.

    The output has the same shape and dtype as ``prefix_lens_tensor``. The
    kernel is launched with one program per block of 256 entries.
    """
    block_size = 256
    num_entries = prefix_lens_tensor.shape[0]
    last_locs = torch.empty_like(prefix_lens_tensor)

    launch_grid = (triton.cdiv(num_entries, block_size),)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        last_locs,
        num_entries,
        req_to_token.stride(0),
        BLOCK_SIZE=block_size,
    )
    return last_locs
def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    """Allocate ``num_tokens`` token slots from the tree cache's allocator.

    ``evict_from_tree_cache`` is called first to make room. When
    ``backup_state`` is true, the allocator state is captured after eviction
    and before allocation.

    Returns:
        The allocated locations, or ``(locations, state)`` when
        ``backup_state`` is true.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    saved_state = kv_allocator.backup_state() if backup_state else None
    out_cache_loc = kv_allocator.alloc(num_tokens)

    if out_cache_loc is not None:
        if backup_state:
            return out_cache_loc, saved_state
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Out of memory. Try to lower your batch size.\n"
        f"Try to allocate {num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    """Evict from ``tree_cache`` when its allocator has fewer than ``num_tokens`` free.

    Does nothing when ``tree_cache`` is ``None`` or is a ``ChunkCache`` /
    ``SWAChunkCache``. An allocator exposing ``full_available_size`` is treated
    as hybrid, with the full and SWA pools checked separately.
    """
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    kv_allocator = tree_cache.token_to_kv_pool_allocator

    if not hasattr(kv_allocator, "full_available_size"):
        # Standard allocator: one pool; request eviction of num_tokens.
        if kv_allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)
        return

    # Hybrid allocator: evict only the shortfall of each pool.
    full_free = kv_allocator.full_available_size()
    swa_free = kv_allocator.swa_available_size()
    if full_free < num_tokens or swa_free < num_tokens:
        full_shortfall = max(0, num_tokens - full_free)
        swa_shortfall = max(0, num_tokens - swa_free)
        tree_cache.evict(full_shortfall, swa_shortfall)
def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    """Allocate paged KV cache slots for an extend batch.

    ``evict_from_tree_cache`` is called first with an over-estimate of the
    demand, then the allocator's ``alloc_extend`` performs the allocation.
    When ``backup_state`` is true, the allocator state is captured after
    eviction and before allocation.

    Returns:
        The allocated locations, or ``(locations, state)`` when
        ``backup_state`` is true.

    Raises:
        RuntimeError: if ``alloc_extend`` returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Over estimate the number of tokens: assume each request needs a new page.
    extra_page_tokens = len(seq_lens_cpu) * kv_allocator.page_size
    evict_from_tree_cache(tree_cache, extend_num_tokens + extra_page_tokens)

    saved_state = kv_allocator.backup_state() if backup_state else None

    out_cache_loc = kv_allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )

    if out_cache_loc is not None:
        if backup_state:
            return out_cache_loc, saved_state
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Prefill out of memory. Try to lower your batch size.\n"
        f"Try to allocate {extend_num_tokens} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
    tree_cache: BasePrefixCache | None,
) -> list[int]:
    """Allocate ``num_reqs`` request slots from ``req_to_token_pool``.

    For a ``HybridReqToTokenPool``, if its mamba pool has fewer than
    ``num_reqs`` free entries and ``tree_cache`` is a ``MambaRadixCache``, the
    shortfall is evicted from the cache first; ``reqs`` is then passed on to
    the pool's ``alloc``. Other pools are asked for ``num_reqs`` slots only.

    Raises:
        RuntimeError: if the pool's ``alloc`` returns ``None``.
    """
    if not isinstance(req_to_token_pool, HybridReqToTokenPool):
        req_pool_indices = req_to_token_pool.alloc(num_reqs)
    else:
        mamba_free = req_to_token_pool.mamba_pool.available_size()
        mamba_is_short = mamba_free < num_reqs
        if mamba_is_short and isinstance(tree_cache, MambaRadixCache):
            tree_cache.evict_mamba(max(0, num_reqs - mamba_free))
        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)

    if req_pool_indices is not None:
        return req_pool_indices

    error_msg = (
        "alloc_req_slots runs out of memory. "
        "Please set a smaller number for `--max-running-requests`. "
        f"{req_to_token_pool.available_size()=}, "
        f"{num_reqs=}, "
    )
    raise RuntimeError(error_msg)
def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Allocate request slots and KV cache for an extend batch, then record the
    resulting locations in ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: allocated cache locations
        req_pool_indices_device: request pool indices as a device tensor
        req_pool_indices: request pool indices as list
    """
    tree_cache = batch.tree_cache
    device = batch.device

    # free out-of-window swa tokens
    if isinstance(tree_cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            tree_cache.evict_swa(
                req, pre_len, batch.model_config.attention_chunk_size
            )

    prefix_tensors = [req.prefix_indices for req in batch.reqs]

    # Length tensors: built on CPU first, then copied to the batch device.
    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
    prefix_lens_device = prefix_lens_cpu.to(device, non_blocking=True)
    extend_lens_device = extend_lens_cpu.to(device, non_blocking=True)

    # Request slots
    req_pool_indices = alloc_req_slots(
        batch.req_to_token_pool, len(batch.reqs), batch.reqs, tree_cache
    )
    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
    req_pool_indices_device = req_pool_indices_cpu.to(device, non_blocking=True)

    # KV cache slots (the allocation helpers raise on failure)
    if tree_cache.page_size != 1:
        # Paged allocation needs the last prefix location of every request;
        # requests without a prefix contribute -1.
        last_loc_parts = []
        for prefix in prefix_tensors:
            if len(prefix) > 0:
                last_loc_parts.append(prefix[-1:])
            else:
                last_loc_parts.append(torch.tensor([-1], device=device))
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=tree_cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=torch.cat(last_loc_parts),
            extend_num_tokens=batch.extend_num_tokens,
        )
    else:
        out_cache_loc = alloc_token_slots(tree_cache, batch.extend_num_tokens)

    # Record prefix + new locations in req_to_token_pool
    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices
def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV cache for decode batch.

    ``evict_from_tree_cache`` is called first with an over-estimate of the
    demand, then the allocator's ``alloc_decode`` performs the allocation.
    ``token_per_req`` is only used in the out-of-memory message.

    Raises:
        RuntimeError: if ``alloc_decode`` returns ``None``.
    """
    kv_allocator = tree_cache.token_to_kv_pool_allocator

    # Over estimate the number of tokens: assume each request needs a new page.
    evict_from_tree_cache(tree_cache, len(seq_lens) * kv_allocator.page_size)

    out_cache_loc = kv_allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is not None:
        return out_cache_loc

    # Allocation failed: log the pool status, dump the cache, then raise.
    error_msg = (
        f"Decode out of memory. Try to lower your batch size.\n"
        f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
        f"{available_and_evictable_str(tree_cache)}"
    )
    logger.error(error_msg)
    if tree_cache is not None:
        tree_cache.pretty_print()
    raise RuntimeError(error_msg)
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """
    Allocate KV cache for a decode batch and record the new locations in
    ``batch.req_to_token_pool``.

    Returns:
        out_cache_loc: allocated cache locations
    """
    tree_cache = batch.tree_cache
    req_to_token_pool = batch.req_to_token_pool

    if isinstance(tree_cache, SWAChunkCache):
        for req in batch.reqs:
            tree_cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    bs = batch.seq_lens.shape[0]

    if tree_cache.page_size != 1:
        # Paged allocation: needs each request's entry at position
        # seq_len - 1 and the sequence lengths after this step.
        last_loc = req_to_token_pool.req_to_token[
            batch.req_pool_indices, batch.seq_lens - 1
        ]
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=tree_cache,
            seq_lens=batch.seq_lens + token_per_req,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )
    else:
        # Non-paged allocation
        out_cache_loc = alloc_token_slots(tree_cache, bs * token_per_req)

    # Column to write per request; encoder-decoder models add encoder_lens.
    locs = (
        batch.encoder_lens + batch.seq_lens
        if batch.model_config.is_encoder_decoder
        else batch.seq_lens.clone()
    )
    req_to_token_pool.write(
        (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
    )

    return out_cache_loc
def available_and_evictable_str(tree_cache) -> str:
    """Format the available and evictable token counts of ``tree_cache``.

    With a ``SWATokenToKVPoolAllocator`` the report has separate lines for the
    full and SWA pools plus their LRU-list evictable sizes; any other allocator
    yields a single line.
    """
    allocator = tree_cache.token_to_kv_pool_allocator

    if not isinstance(allocator, SWATokenToKVPoolAllocator):
        available_size = allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    # These local names appear in the output through the `{name=}` specifiers.
    full_available_size = allocator.full_available_size()
    swa_available_size = allocator.swa_available_size()
    full_evictable_size = tree_cache.full_evictable_size()
    swa_evictable_size = tree_cache.swa_evictable_size()

    report_lines = [
        f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n",
        f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n",
        f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n",
        f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n",
    ]
    return "".join(report_lines)

Xet Storage Details

Size:
16.3 kB
·
Xet hash:
e958e9f900507728640ab18fb87be5fc22ff5a483e1a391953c586e517d49a3f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.