| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Run the model with cuda graph and torch.compile.""" | |
| from __future__ import annotations | |
| import bisect | |
| import gc | |
| import inspect | |
| import logging | |
| import os | |
| from contextlib import contextmanager | |
| from typing import TYPE_CHECKING, Callable, Optional, Union | |
| import torch | |
| import tqdm | |
| from torch.profiler import ProfilerActivity, profile | |
| from sglang.srt.custom_op import CustomOp | |
| from sglang.srt.distributed import get_tensor_model_parallel_rank | |
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( | |
| set_graph_pool_id, | |
| ) | |
| from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture | |
| from sglang.srt.layers.dp_attention import ( | |
| DpPaddingMode, | |
| get_attention_tp_rank, | |
| get_attention_tp_size, | |
| set_dp_buffer_len, | |
| set_is_extend_in_batch, | |
| ) | |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | |
| from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer | |
| from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend | |
| from sglang.srt.layers.torchao_utils import save_gemlite_cache | |
| from sglang.srt.model_executor.forward_batch_info import ( | |
| CaptureHiddenMode, | |
| ForwardBatch, | |
| ForwardMode, | |
| PPProxyTensors, | |
| enable_num_token_non_padded, | |
| ) | |
| from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin | |
| from sglang.srt.utils import ( | |
| empty_context, | |
| get_available_gpu_memory, | |
| get_bool_env_var, | |
| is_hip, | |
| log_info_on_rank0, | |
| require_attn_tp_gather, | |
| require_gathered_buffer, | |
| require_mlp_sync, | |
| require_mlp_tp_gather, | |
| ) | |
| from sglang.srt.utils.patch_torch import monkey_patch_torch_compile | |
| try: | |
| from kt_kernel import AMXMoEWrapper | |
| KTRANSFORMERS_AVAILABLE = True | |
| except ImportError: | |
| KTRANSFORMERS_AVAILABLE = False | |
| _is_hip = is_hip() | |
| logger = logging.getLogger(__name__) | |
| if TYPE_CHECKING: | |
| from sglang.srt.model_executor.model_runner import ModelRunner | |
# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    """Return True while a ``model_capture_mode()`` context is active."""
    return is_capture_mode


@contextmanager
def model_capture_mode():
    """Context manager that marks the enclosed code as running in capture mode.

    Sets the module-level ``is_capture_mode`` flag to True on entry and back
    to False on exit. The reset happens in a ``finally`` clause so that a
    failed capture does not leave the flag stuck at True.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Optimize garbage collection during CUDA graph capture.
    Clean up, then freeze all remaining objects from being included
    in future collections if GC is disabled during capture.

    Args:
        enable_cudagraph_gc: When False, the objects alive on entry are frozen
            for the duration of the context and unfrozen (followed by a
            collection) on exit. When True, only the initial collection runs.
    """
    gc.collect()
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        # Always undo the freeze, even if the body raised.
        if should_freeze:
            gc.unfreeze()
            gc.collect()
| def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): | |
| for sub in model._modules.values(): | |
| if isinstance(sub, CustomOp): | |
| if reverse: | |
| sub.leave_torch_compile() | |
| else: | |
| sub.enter_torch_compile(num_tokens=num_tokens) | |
| if isinstance(sub, torch.nn.Module): | |
| _to_torch(sub, reverse, num_tokens) | |
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile.

    Used as a context manager that yields the callable to run for the forward
    pass.

    Args:
        model: Model whose ``forward`` is yielded (optionally compiled).
        enable_compile: If False, ``model.forward`` is yielded unchanged and
            nothing is patched.
        num_tokens: Passed to the custom ops when entering compile mode.
        tp_group: Group whose ``ca_comm`` attribute is saved on entry and
            restored on exit when compiling.

    Yields:
        ``model.forward`` or a ``torch.compile``-wrapped, no-grad version of it.
    """
    if not enable_compile:
        yield model.forward
        return

    # Take the backup before anything can fail, so the ``finally`` clause never
    # overwrites ``tp_group.ca_comm`` with a placeholder value.
    backup_ca_comm = tp_group.ca_comm
    try:
        _to_torch(model, reverse=False, num_tokens=num_tokens)
        # Use custom-allreduce here.
        # We found the custom allreduce is much faster than the built-in allreduce in torch,
        # even with ENABLE_INTRA_NODE_COMM=1.
        # tp_group.ca_comm = None
        yield torch.compile(
            torch.no_grad()(model.forward),
            mode=os.environ.get(
                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
            ),
            dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
        )
    finally:
        _to_torch(model, reverse=True, num_tokens=num_tokens)
        tp_group.ca_comm = backup_ca_comm
def set_torch_compile_config():
    """Apply the inductor/dynamo settings used when torch.compile is enabled.

    Enables coordinate-descent tuning, unique Triton kernel names and the FX
    graph cache, raises the dynamo cache size limits to 1024, and finally
    calls ``monkey_patch_torch_compile()``.
    """
    import torch._dynamo.config
    import torch._inductor.config

    inductor_cfg = torch._inductor.config
    dynamo_cfg = torch._dynamo.config

    inductor_cfg.coordinate_descent_tuning = True
    inductor_cfg.triton.unique_kernel_names = True
    # Experimental feature to reduce compilation times, will be on by default in future
    inductor_cfg.fx_graph_cache = True

    # FIXME: tmp workaround
    dynamo_cfg.accumulated_cache_size_limit = 1024
    if hasattr(dynamo_cfg, "cache_size_limit"):
        dynamo_cfg.cache_size_limit = 1024

    monkey_patch_torch_compile()
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    """Compute the batch sizes to capture as cuda graphs and to torch.compile.

    Starts from ``server_args.cuda_graph_bs``, adds the request pool size when
    the configured maximum exceeds it, then keeps only sizes that are a
    multiple of the required base and do not exceed the request pool size.

    Args:
        model_runner: Provides ``server_args`` and ``req_to_token_pool``.

    Returns:
        ``(capture_bs, compile_bs)``: ``capture_bs`` is sorted ascending and
        de-duplicated; ``compile_bs`` is the subset not larger than
        ``server_args.torch_compile_max_bs`` (empty when torch.compile is off).

    Raises:
        AssertionError: If no positive batch size survives the filtering.
    """
    server_args = model_runner.server_args
    pool_size = model_runner.req_to_token_pool.size

    # Work on a copy: an in-place ``+=`` here would silently grow the list
    # stored on ``server_args`` every time this function is called.
    capture_bs = list(server_args.cuda_graph_bs)

    if max(capture_bs) > pool_size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs.append(pool_size)

    mul_base = 1

    if server_args.enable_two_batch_overlap:
        mul_base *= 2

    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    capture_bs = sorted(
        {bs for bs in capture_bs if bs % mul_base == 0 and bs <= pool_size}
    )

    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    """Return the shared graph memory pool handle, or None if none was set."""
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    """Store ``val`` as the graph memory pool handle shared by all runners."""
    global global_graph_memory_pool
    global_graph_memory_pool = val
class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.

    Construction allocates persistent input buffers sized for the largest
    batch size in ``capture_bs`` and immediately captures one graph per batch
    size. ``replay`` copies a live batch into those buffers, pads it to the
    smallest captured batch size that is not smaller than the batch, and
    replays the corresponding graph.
    """

    def __init__(self, model_runner: ModelRunner):
        """Allocate the static graph buffers and capture all graphs.

        Args:
            model_runner: Source of the model, server args, attention backend
                and memory pools used during capture and replay.

        Raises:
            RuntimeError: If speculative decoding is enabled and
                ``model_runner`` is a draft worker.
            Exception: If capture fails with a ``RuntimeError``; the original
                error is attached as ``__cause__``.
        """
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.device_module = torch.get_device_module(self.device)
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()

        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        if KTRANSFORMERS_AVAILABLE:
            AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
        if (
            model_runner.spec_algorithm.is_eagle()
            or model_runner.spec_algorithm.is_standalone()
            or model_runner.spec_algorithm.is_ngram()
        ):
            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
                    self.model_runner.server_args.speculative_num_draft_tokens
                )

        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.model_runner.attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        if self.model_runner.server_args.enable_lora:
            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros(
                (self.max_num_token,), dtype=self._cache_loc_dtype()
            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros(
                (3, self.max_num_token), dtype=torch.int64
            )
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
            self.tbo_plugin = TboCudaGraphRunnerPlugin()

            # pipeline parallelism
            if self.pp_size > 1:
                self.pp_proxy_tensors = {
                    "hidden_states": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                    "residual": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                }

            # Speculative_inference
            if model_runner.spec_algorithm.is_eagle3():
                self.model_runner.model.set_eagle3_layers_to_capture()

            if self.is_encoder_decoder:
                # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
                self.encoder_lens = torch.full(
                    (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
                )
            else:
                self.encoder_lens = None

            if self.require_gathered_buffer:
                if self.require_mlp_tp_gather:
                    self.global_num_tokens_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                else:
                    assert self.require_attn_tp_gather
                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (1,), dtype=torch.int32
                    )
            else:
                self.global_num_tokens_gpu = None
                self.global_num_tokens_for_logprob_gpu = None

            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )
            self.next_token_logits_buffer = torch.zeros(
                (self.max_num_token, self.model_runner.model_config.vocab_size),
                dtype=torch.float,
                device=self.device,
            )

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            # Chain explicitly so the underlying error stays attached as the cause.
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            ) from e

    def _cache_loc_dtype(self):
        """Return the dtype used for the ``out_cache_loc`` buffer."""
        return torch.int64

    def can_run(self, forward_batch: ForwardBatch):
        """Return a truthy value if ``forward_batch`` can be served by a captured graph.

        The batch size must be supported (an exact match when padding is
        disabled, otherwise not larger than ``max_bs``), and the encoder
        lengths, two-batch-overlap, hidden-state capture mode and ngram token
        count checks below must all pass.
        """
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )

        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
        # because the full_text_row_masked_out_mask tensor will always be ones
        is_encoder_lens_supported = (
            torch.all(forward_batch.encoder_lens > 0)
            if self.is_encoder_decoder
            else True
        )

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

        # With ngram speculation the batch must carry exactly
        # ``num_tokens_per_bs`` input ids per request.
        is_ngram_supported = (
            (
                forward_batch.batch_size * self.num_tokens_per_bs
                == forward_batch.input_ids.numel()
            )
            if self.model_runner.spec_algorithm.is_ngram()
            else True
        )

        return (
            is_bs_supported
            and is_encoder_lens_supported
            and is_tbo_supported
            and capture_hidden_mode_matches
            and is_ngram_supported
        )

    def capture(self) -> None:
        """Capture one graph per batch size in ``capture_bs``, largest first.

        Results are stored in ``self.graphs`` and ``self.output_buffers`` keyed
        by batch size. When ``enable_profile_cuda_graph`` is set, the capture
        runs under the torch profiler, a memory snapshot is dumped and a
        summary is logged.
        """
        profile_context = empty_context()
        if self.enable_profile_cuda_graph:
            profile_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
            )
            torch.cuda.memory._record_memory_history()

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )
                # Reverse the order to enable better memory sharing across cuda graphs.
                capture_range = (
                    tqdm.tqdm(list(reversed(self.capture_bs)))
                    if get_tensor_model_parallel_rank() == 0
                    else reversed(self.capture_bs)
                )
                for bs in capture_range:
                    # Only rank 0 wraps the range in tqdm, so only rank 0
                    # may update the progress bar description.
                    if get_tensor_model_parallel_rank() == 0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        (
                            graph,
                            output_buffers,
                        ) = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs] = graph
                        self.output_buffers[bs] = output_buffers

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

        if self.enable_profile_cuda_graph:
            torch.cuda.memory._dump_snapshot("cuda_graph_runner_memory_usage.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)
            log_message = (
                "Sorted by CUDA Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cuda_time_total", row_limit=10
                )
                + "\n\nSorted by CPU Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=10
                )
                + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
            )
            logger.info(log_message)

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        """Run ``run_once_fn`` inside a device graph capture and return its output."""
        with self.device_module.graph(graph, pool=pool, stream=stream):
            out = run_once_fn()
        return out

    def _create_device_graph(self):
        """Create the empty graph object that a capture records into."""
        return torch.cuda.CUDAGraph()

    def capture_one_batch_size(self, bs: int, forward: Callable):
        """Capture the graph for a single batch size.

        Args:
            bs: Batch size to capture; inputs are views of the first ``bs``
                (or ``bs * num_tokens_per_bs``) entries of the static buffers.
            forward: Callable invoked as
                ``forward(input_ids, positions, forward_batch, **kwargs)``.

        Returns:
            ``(graph, out)`` where ``out`` is whatever ``forward`` returned
            during the captured run.
        """
        graph = self._create_device_graph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        seq_lens_cpu = self.seq_lens_cpu[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :num_tokens]
        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
        self.num_token_non_padded[...] = num_tokens

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens
        else:
            global_dp_buffer_len = None

        spec_info = self.get_spec_info(num_tokens)
        # A FULL capture mode is never downgraded here; otherwise follow the
        # mode requested by the speculative-decoding input (NULL without one).
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
            lora_ids=lora_ids,
        )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
            set_is_extend_in_batch(False)

            kwargs = {}
            # Only pass proxy tensors to forward functions that accept them.
            if (
                self.pp_size > 1
                and "pp_proxy_tensors" in inspect.signature(forward).parameters
            ):
                kwargs["pp_proxy_tensors"] = PPProxyTensors(
                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
                )

            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
                **kwargs,
            )
            return logits_output_or_pp_proxy_tensors

        self.deepep_adapter.capture(is_extend_in_batch=False)

        # Two synchronized runs outside the graph before the captured run.
        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())
        out = self._capture_graph(
            graph, get_global_graph_memory_pool(), stream, run_once
        )

        return graph, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):
        """Re-capture all graphs if the required hidden-state capture mode changed."""
        # If the required capture_hidden_mode changes, we need to recapture the graph

        # These are the different factors that can influence the capture_hidden_mode
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = (
            getattr(forward_batch.spec_info, "capture_hidden_mode", None)
            or CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        """Copy ``forward_batch`` into the static graph buffers and prepare replay.

        Picks the captured batch size to pad to, copies the batch inputs into
        the persistent buffers, refreshes the attention backend's replay
        metadata, and records ``raw_bs``, ``raw_num_token`` and ``bs`` on
        ``self`` for ``replay``.

        Args:
            forward_batch: The live batch to run.
            pp_proxy_tensors: Optional pipeline-parallel inputs copied into
                the matching static buffers.
        """
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
        if self.require_mlp_tp_gather:
            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
            # NOTE: true division may produce a float; bisect_left then selects
            # the first captured size that is >= that value.
            max_batch_size = (
                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max_num_tokens
            )
            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            # Reset the buffers so padded slots hold the fill values.
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        seq_lens_cpu = None
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            seq_lens_cpu = self.seq_lens_cpu[:bs]

        if pp_proxy_tensors:
            for key in self.pp_proxy_tensors.keys():
                dim = pp_proxy_tensors[key].shape[0]
                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            num_token_non_padded = forward_batch.num_token_non_padded
            if self.require_gathered_buffer:
                # Clamp the global non-padded count to the window of tokens
                # assigned to this attention-TP rank.
                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
                num_local_token_non_padded = torch.clamp(
                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
                    min=0,
                    max=tokens_per_rank,
                )
                self.num_token_non_padded.copy_(num_local_token_non_padded)
            else:
                self.num_token_non_padded.copy_(num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
                spec_info=forward_batch.spec_info,
            )
        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
            forward_batch.spec_info.custom_mask = self.custom_mask
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            # Padded slots each contribute ``seq_len_fill_value`` to the sum.
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            self.capture_forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        """Replay the captured graph for ``forward_batch`` and return its output.

        Args:
            forward_batch: The live batch to run.
            skip_attn_backend_init: If True, skip ``replay_prepare`` and only
                refresh ``input_ids`` and ``positions``; ``raw_num_token`` and
                ``bs`` from an earlier ``replay_prepare`` call are reused.
            pp_proxy_tensors: Optional pipeline-parallel inputs forwarded to
                ``replay_prepare``.

        Returns:
            A ``LogitsProcessorOutput`` truncated to the un-padded token count,
            or a ``PPProxyTensors`` whose tensors are sliced to ``self.bs``.
        """
        self.deepep_adapter.replay()

        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

    def get_spec_info(self, num_tokens: int):
        """Build the placeholder speculative-decoding input used during capture.

        Returns an ``EagleVerifyInput`` for the eagle/standalone algorithms,
        an ``NgramVerifyInput`` for ngram, and None otherwise.

        Raises:
            RuntimeError: If the eagle/standalone branch is reached on a
                draft worker.
        """
        spec_info = None
        if (
            self.model_runner.spec_algorithm.is_eagle()
            or self.model_runner.spec_algorithm.is_standalone()
        ):
            from sglang.srt.speculative.eagle_info import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        elif self.model_runner.spec_algorithm.is_ngram():
            from sglang.srt.speculative.ngram_info import NgramVerifyInput

            spec_info = NgramVerifyInput(
                draft_token=None,
                tree_mask=self.custom_mask,
                positions=None,
                retrive_index=None,
                retrive_next_token=None,
                retrive_next_sibling=None,
                draft_token_num=self.num_tokens_per_bs,
            )
            spec_info.capture_hidden_mode = CaptureHiddenMode.NULL

        return spec_info
# Troubleshooting hints appended to the error raised when graph capture fails.
# Every entry becomes one newline-terminated line of the message.
CUDA_GRAPH_CAPTURE_FAILED_MSG = "".join(
    f"{hint}\n"
    for hint in (
        "Possible solutions:",
        "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)",
        "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)",
        "3. disable torch compile by not using --enable-torch-compile",
        "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)",
        "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose ",
    )
)
class DeepEPCudaGraphRunnerAdapter:
    """Remembers the DeepEP dispatch mode chosen at capture time and re-applies it at replay.

    Both ``capture`` and ``replay`` do nothing unless the MoE all-to-all
    backend reports that it is DeepEP.
    """

    def __init__(self):
        # Record DeepEP mode used during capture to ensure replay consistency
        self._captured_deepep_mode = None

    def capture(self, is_extend_in_batch: bool):
        """Resolve, store and apply the dispatch mode for the graph being captured."""
        if get_moe_a2a_backend().is_deepep():
            resolved_mode = get_deepep_mode().resolve(
                is_extend_in_batch=is_extend_in_batch
            )
            self._captured_deepep_mode = resolved_mode
            DeepEPBuffer.set_dispatch_mode(resolved_mode)

    def replay(self):
        """Re-apply the dispatch mode stored by ``capture``."""
        if get_moe_a2a_backend().is_deepep():
            stored_mode = self._captured_deepep_mode
            # ``capture`` must have run first.
            assert stored_mode is not None
            DeepEPBuffer.set_dispatch_mode(stored_mode)
Xet Storage Details
- Size:
- 35.6 kB
- Xet hash:
- a562a3ec6165389e9a48cec0068b301305bb2fe0c9def038946aae3dfd624bbe
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.