| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Run the model with cuda graph and torch.compile.""" | |
| from __future__ import annotations | |
| import bisect | |
| import gc | |
| import logging | |
| from contextlib import contextmanager | |
| from typing import TYPE_CHECKING, Union | |
| import torch | |
| import tqdm | |
| from sglang.srt.compilation.compilation_config import CompilationConfig | |
| from sglang.srt.compilation.compile import install_torch_compiled, set_compiled | |
| from sglang.srt.compilation.piecewise_context_manager import set_forward_context | |
| from sglang.srt.custom_op import CustomOp | |
| from sglang.srt.distributed import get_tensor_model_parallel_rank | |
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( | |
| set_graph_pool_id, | |
| ) | |
| from sglang.srt.distributed.parallel_state import graph_capture | |
| from sglang.srt.layers.dp_attention import ( | |
| DpPaddingMode, | |
| get_attention_tp_rank, | |
| get_attention_tp_size, | |
| set_dp_buffer_len, | |
| set_is_extend_in_batch, | |
| ) | |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | |
| from sglang.srt.layers.torchao_utils import save_gemlite_cache | |
| from sglang.srt.model_executor.forward_batch_info import ( | |
| CaptureHiddenMode, | |
| ForwardBatch, | |
| ForwardMode, | |
| PPProxyTensors, | |
| ) | |
| from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin | |
| from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0 | |
| logger = logging.getLogger(__name__) | |
| if TYPE_CHECKING: | |
| from sglang.srt.model_executor.model_runner import ModelRunner | |
| # Detect whether the current forward pass is in capture mode | |
| is_capture_mode = False | |
| def get_is_capture_mode(): | |
| return is_capture_mode | |
| def model_capture_mode(): | |
| global is_capture_mode | |
| is_capture_mode = True | |
| yield | |
| is_capture_mode = False | |
| def freeze_gc(enable_cudagraph_gc: bool): | |
| """ | |
| Optimize garbage collection during CUDA graph capture. | |
| Clean up, then freeze all remaining objects from being included | |
| in future collections if GC is disabled during capture. | |
| """ | |
| gc.collect() | |
| should_freeze = not enable_cudagraph_gc | |
| if should_freeze: | |
| gc.freeze() | |
| try: | |
| yield | |
| finally: | |
| if should_freeze: | |
| gc.unfreeze() | |
| def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): | |
| for sub in model._modules.values(): | |
| if isinstance(sub, CustomOp): | |
| if reverse: | |
| sub.leave_torch_compile() | |
| else: | |
| sub.enter_torch_compile(num_tokens=num_tokens) | |
| if isinstance(sub, torch.nn.Module): | |
| _to_torch(sub, reverse, num_tokens) | |
| def patch_model(model: torch.nn.Module, compiler: str): | |
| try: | |
| if compiler != "eager": | |
| _to_torch(model, reverse=False, num_tokens=16) | |
| yield model | |
| finally: | |
| _to_torch(model, reverse=True, num_tokens=16) | |
| # Reuse this memory pool across all cuda graph runners. | |
| global_graph_memory_pool = None | |
| def get_global_graph_memory_pool(): | |
| return global_graph_memory_pool | |
| def set_global_graph_memory_pool(val): | |
| global global_graph_memory_pool | |
| global_graph_memory_pool = val | |
| class PiecewiseCudaGraphRunner: | |
| """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" | |
| def __init__(self, model_runner: ModelRunner): | |
| # Parse args | |
| self.model_runner = model_runner | |
| self.device = model_runner.device | |
| self.device_module = torch.get_device_module(self.device) | |
| self.graphs = {} | |
| self.output_buffers = {} | |
| self.tp_size = model_runner.server_args.tp_size | |
| self.dp_size = model_runner.server_args.dp_size | |
| self.pp_size = model_runner.server_args.pp_size | |
| self.attn_tp_size = get_attention_tp_size() | |
| self.attn_tp_rank = get_attention_tp_rank() | |
| assert ( | |
| self.model_runner.server_args.piecewise_cuda_graph_tokens is not None | |
| ), "piecewise_cuda_graph_tokens is not set" | |
| assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [ | |
| "eager", | |
| "inductor", | |
| ], "By now, only eager and inductor are supported for piecewise cuda graph compiler." | |
| self.compile_config = CompilationConfig( | |
| self.model_runner.server_args.piecewise_cuda_graph_tokens, | |
| self.model_runner.server_args.piecewise_cuda_graph_compiler, | |
| ) | |
| # Batch sizes to capture | |
| self.capture_num_tokens = self.compile_config.get_capture_sizes() | |
| log_info_on_rank0( | |
| logger, f"Capture cuda graph num tokens {self.capture_num_tokens}" | |
| ) | |
| self.capture_forward_mode = ForwardMode.EXTEND | |
| self.capture_hidden_mode = CaptureHiddenMode.NULL | |
| # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup | |
| if model_runner.server_args.enable_return_hidden_states: | |
| self.capture_hidden_mode = CaptureHiddenMode.FULL | |
| # Attention backend | |
| self.max_num_tokens = max(self.capture_num_tokens) | |
| # Graph inputs | |
| with torch.device(self.device): | |
| self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64) | |
| self.out_cache_loc = torch.zeros( | |
| (self.max_num_tokens,), dtype=self._cache_loc_dtype() | |
| ) | |
| self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64) | |
| self.tbo_plugin = TboCudaGraphRunnerPlugin() | |
| self.attention_layers = self.model_runner.attention_layers | |
| if get_global_graph_memory_pool() is None: | |
| set_global_graph_memory_pool(self.device_module.graph_pool_handle()) | |
| # Set graph pool id globally to be able to use symmetric memory | |
| set_graph_pool_id(get_global_graph_memory_pool()) | |
| with patch_model( | |
| self.model_runner.model.model, self.compile_config.compiler | |
| ) as patched_model: | |
| install_torch_compiled( | |
| patched_model, | |
| fullgraph=True, | |
| dynamic_arg_dims=None, | |
| compile_config=self.compile_config, | |
| graph_pool=get_global_graph_memory_pool(), | |
| ) | |
| with set_compiled(True): | |
| self.warmup_and_capture() | |
| # Capture | |
| try: | |
| with model_capture_mode(): | |
| self.capture() | |
| except RuntimeError as e: | |
| raise Exception( | |
| f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}" | |
| ) | |
| self.raw_num_tokens = 0 | |
| def warmup_and_capture(self): | |
| num_tokens = 2 | |
| with torch.device(self.device): | |
| forward_batch = ForwardBatch( | |
| forward_mode=ForwardMode.EXTEND, | |
| batch_size=1, | |
| input_ids=torch.randint(0, 100, (num_tokens,), device=self.device), | |
| req_pool_indices=torch.arange(1, device=self.device), | |
| seq_lens=torch.tensor([num_tokens], device=self.device), | |
| next_token_logits_buffer=None, | |
| orig_seq_lens=torch.tensor([num_tokens], device=self.device), | |
| seq_lens_cpu=torch.tensor([num_tokens]), | |
| req_to_token_pool=self.model_runner.req_to_token_pool, | |
| token_to_kv_pool=self.model_runner.token_to_kv_pool, | |
| attn_backend=self.model_runner.attn_backend, | |
| out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device), | |
| seq_lens_sum=num_tokens, | |
| encoder_lens=None, | |
| return_logprob=False, | |
| extend_seq_lens=torch.tensor([num_tokens], device=self.device), | |
| extend_prefix_lens=torch.tensor([num_tokens], device=self.device), | |
| extend_start_loc=torch.tensor([0], device=self.device), | |
| extend_prefix_lens_cpu=torch.tensor([num_tokens]), | |
| extend_seq_lens_cpu=torch.tensor([num_tokens]), | |
| extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), | |
| positions=torch.arange(num_tokens, device=self.device), | |
| global_num_tokens_gpu=None, | |
| global_num_tokens_for_logprob_gpu=None, | |
| dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), | |
| global_dp_buffer_len=None, | |
| mrope_positions=None, | |
| spec_algorithm=None, | |
| spec_info=None, | |
| capture_hidden_mode=CaptureHiddenMode.NULL, | |
| num_token_non_padded=None, | |
| global_forward_mode=ForwardMode.EXTEND, | |
| lora_ids=None, | |
| ) | |
| with set_forward_context(forward_batch, self.attention_layers): | |
| _ = self.model_runner.model.forward( | |
| forward_batch.input_ids, | |
| forward_batch.positions, | |
| forward_batch, | |
| ) | |
| def _cache_loc_dtype(self): | |
| return torch.int64 | |
| def can_run(self, forward_batch: ForwardBatch): | |
| num_tokens = len(forward_batch.input_ids) | |
| # TODO(yuwei): support return logprob | |
| if forward_batch.return_logprob: | |
| return False | |
| if num_tokens <= self.max_num_tokens: | |
| return True | |
| return False | |
| def capture(self) -> None: | |
| # Trigger CUDA graph capture for specific shapes. | |
| # Capture the large shapes first so that the smaller shapes | |
| # can reuse the memory pool allocated for the large shapes. | |
| with freeze_gc( | |
| self.model_runner.server_args.enable_cudagraph_gc | |
| ), graph_capture() as graph_capture_context: | |
| self.stream = graph_capture_context.stream | |
| avail_mem = get_available_gpu_memory( | |
| self.model_runner.device, | |
| self.model_runner.gpu_id, | |
| empty_cache=False, | |
| ) | |
| # Reverse the order to enable better memory sharing across cuda graphs. | |
| capture_range = ( | |
| tqdm.tqdm(list(reversed(self.capture_num_tokens))) | |
| if get_tensor_model_parallel_rank() == 0 | |
| else reversed(self.capture_num_tokens) | |
| ) | |
| for i, num_tokens in enumerate(capture_range): | |
| if get_tensor_model_parallel_rank() == 0: | |
| avail_mem = get_available_gpu_memory( | |
| self.model_runner.device, | |
| self.model_runner.gpu_id, | |
| empty_cache=False, | |
| ) | |
| capture_range.set_description( | |
| f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)" | |
| ) | |
| with set_compiled(True): | |
| self.capture_one_batch_size(num_tokens) | |
| # Save gemlite cache after each capture | |
| save_gemlite_cache() | |
| def capture_one_batch_size(self, num_tokens: int): | |
| stream = self.stream | |
| bs = 1 | |
| # Graph inputs | |
| input_ids = self.input_ids[:num_tokens] | |
| out_cache_loc = self.out_cache_loc[:num_tokens] | |
| positions = self.positions[:num_tokens] | |
| # pipeline parallelism | |
| if self.pp_size > 1: | |
| pp_proxy_tensors = PPProxyTensors( | |
| {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} | |
| ) | |
| global_dp_buffer_len = None | |
| if self.model_runner.server_args.enable_lora: | |
| # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever | |
| # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization). | |
| lora_ids = [None] * bs | |
| else: | |
| lora_ids = None | |
| with torch.device(self.device): | |
| forward_batch = ForwardBatch( | |
| forward_mode=ForwardMode.EXTEND, | |
| batch_size=bs, | |
| input_ids=input_ids, | |
| req_pool_indices=torch.arange(bs, device=self.device), | |
| seq_lens=torch.tensor([num_tokens], device=self.device), | |
| next_token_logits_buffer=None, | |
| orig_seq_lens=torch.tensor([num_tokens], device=self.device), | |
| seq_lens_cpu=torch.tensor([num_tokens]), | |
| req_to_token_pool=self.model_runner.req_to_token_pool, | |
| token_to_kv_pool=self.model_runner.token_to_kv_pool, | |
| attn_backend=self.model_runner.attn_backend, | |
| out_cache_loc=out_cache_loc, | |
| seq_lens_sum=num_tokens, | |
| encoder_lens=None, | |
| return_logprob=False, | |
| extend_seq_lens=torch.tensor([num_tokens], device=self.device), | |
| extend_prefix_lens=torch.tensor([num_tokens], device=self.device), | |
| extend_start_loc=torch.tensor([0], device=self.device), | |
| extend_prefix_lens_cpu=torch.tensor([num_tokens]), | |
| extend_seq_lens_cpu=torch.tensor([num_tokens]), | |
| extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), | |
| positions=positions, | |
| global_num_tokens_gpu=None, | |
| global_num_tokens_for_logprob_gpu=None, | |
| dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), | |
| global_dp_buffer_len=None, | |
| mrope_positions=None, | |
| spec_algorithm=None, | |
| spec_info=None, | |
| capture_hidden_mode=CaptureHiddenMode.NULL, | |
| num_token_non_padded=None, | |
| global_forward_mode=ForwardMode.EXTEND, | |
| lora_ids=None, | |
| ) | |
| self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) | |
| if lora_ids is not None: | |
| self.model_runner.lora_manager.prepare_lora_batch(forward_batch) | |
| # # Attention backend | |
| self.model_runner.attn_backend.init_forward_metadata(forward_batch) | |
| # Run and capture | |
| def run_once(): | |
| # Clean intermediate result cache for DP attention | |
| forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None | |
| set_dp_buffer_len(global_dp_buffer_len, num_tokens) | |
| # FIXME: the implementation is hacky. `is_extend_in_batch`` is for determining the deepep mode. | |
| # It is True in this context but we need to set it to use low latency deepep mode. | |
| set_is_extend_in_batch(False) | |
| kwargs = {} | |
| with set_forward_context(forward_batch, self.attention_layers): | |
| self.model_runner.model.forward( | |
| forward_batch.input_ids, | |
| forward_batch.positions, | |
| forward_batch, | |
| **kwargs, | |
| ) | |
| return | |
| for _ in range(2): | |
| self.device_module.synchronize() | |
| self.model_runner.tp_group.barrier() | |
| run_once() | |
| return | |
| def replay_prepare( | |
| self, | |
| forward_batch: ForwardBatch, | |
| **kwargs, | |
| ): | |
| num_tokens = len(forward_batch.input_ids) | |
| index = bisect.bisect_left(self.capture_num_tokens, num_tokens) | |
| static_num_tokens = self.capture_num_tokens[index] | |
| self.raw_num_tokens = num_tokens | |
| if static_num_tokens != num_tokens: | |
| self.out_cache_loc.zero_() | |
| bs = forward_batch.batch_size | |
| self.input_ids[:num_tokens].copy_(forward_batch.input_ids) | |
| self.positions[:num_tokens].copy_(forward_batch.positions) | |
| self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) | |
| input_ids = self.input_ids[:static_num_tokens] | |
| positions = self.positions[:static_num_tokens] | |
| out_cache_loc = self.out_cache_loc[:static_num_tokens] | |
| next_token_logits_buffer = None | |
| mrope_positions = None | |
| static_forward_batch = ForwardBatch( | |
| forward_mode=forward_batch.forward_mode, | |
| batch_size=bs, | |
| input_ids=input_ids, | |
| req_pool_indices=forward_batch.req_pool_indices, | |
| seq_lens=forward_batch.seq_lens, | |
| next_token_logits_buffer=next_token_logits_buffer, | |
| orig_seq_lens=forward_batch.orig_seq_lens, | |
| seq_lens_cpu=forward_batch.seq_lens_cpu, | |
| req_to_token_pool=self.model_runner.req_to_token_pool, | |
| token_to_kv_pool=self.model_runner.token_to_kv_pool, | |
| attn_backend=self.model_runner.attn_backend, | |
| out_cache_loc=out_cache_loc, | |
| seq_lens_sum=forward_batch.seq_lens_sum, | |
| encoder_lens=forward_batch.encoder_lens, | |
| return_logprob=forward_batch.return_logprob, | |
| extend_seq_lens=forward_batch.extend_seq_lens, | |
| extend_prefix_lens=forward_batch.extend_prefix_lens, | |
| extend_start_loc=forward_batch.extend_start_loc, | |
| extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu, | |
| extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, | |
| extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, | |
| extend_num_tokens=forward_batch.extend_num_tokens, | |
| extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, | |
| positions=positions, | |
| global_num_tokens_gpu=forward_batch.global_num_tokens_gpu, | |
| global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, | |
| dp_padding_mode=forward_batch.dp_padding_mode, | |
| global_dp_buffer_len=forward_batch.global_dp_buffer_len, | |
| mrope_positions=mrope_positions, | |
| spec_algorithm=forward_batch.spec_algorithm, | |
| spec_info=forward_batch.spec_info, | |
| capture_hidden_mode=forward_batch.capture_hidden_mode, | |
| num_token_non_padded=forward_batch.num_token_non_padded, | |
| global_forward_mode=forward_batch.global_forward_mode, | |
| lora_ids=forward_batch.lora_ids, | |
| sampling_info=forward_batch.sampling_info, | |
| mm_inputs=forward_batch.mm_inputs, | |
| temp_scaled_logprobs=forward_batch.temp_scaled_logprobs, | |
| temperature=forward_batch.temperature, | |
| top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs, | |
| top_p=forward_batch.top_p, | |
| ) | |
| return static_forward_batch | |
| def replay( | |
| self, | |
| forward_batch: ForwardBatch, | |
| **kwargs, | |
| ) -> Union[LogitsProcessorOutput, PPProxyTensors]: | |
| static_forward_batch = self.replay_prepare(forward_batch, **kwargs) | |
| # Replay | |
| with set_forward_context(static_forward_batch, self.attention_layers): | |
| with set_compiled(True): | |
| output = self.model_runner.model.forward( | |
| static_forward_batch.input_ids, | |
| static_forward_batch.positions, | |
| static_forward_batch, | |
| **kwargs, | |
| ) | |
| if isinstance(output, LogitsProcessorOutput): | |
| return LogitsProcessorOutput( | |
| next_token_logits=output.next_token_logits[: self.raw_num_tokens], | |
| hidden_states=( | |
| output.hidden_states[: self.raw_num_tokens] | |
| if output.hidden_states is not None | |
| else None | |
| ), | |
| ) | |
| else: | |
| assert isinstance(output, PPProxyTensors) | |
| # TODO(Yuwei): support PP Support | |
| raise NotImplementedError( | |
| "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet." | |
| ) | |
| def get_spec_info(self, num_tokens: int): | |
| spec_info = None | |
| if ( | |
| self.model_runner.spec_algorithm.is_eagle() | |
| or self.model_runner.spec_algorithm.is_standalone() | |
| ): | |
| from sglang.srt.speculative.eagle_utils import EagleVerifyInput | |
| if self.model_runner.is_draft_worker: | |
| raise RuntimeError("This should not happen.") | |
| else: | |
| spec_info = EagleVerifyInput( | |
| draft_token=None, | |
| custom_mask=self.custom_mask, | |
| positions=None, | |
| retrive_index=None, | |
| retrive_next_token=None, | |
| retrive_next_sibling=None, | |
| retrive_cum_len=None, | |
| spec_steps=self.model_runner.server_args.speculative_num_steps, | |
| topk=self.model_runner.server_args.speculative_eagle_topk, | |
| draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, | |
| capture_hidden_mode=CaptureHiddenMode.FULL, | |
| seq_lens_sum=None, | |
| seq_lens_cpu=None, | |
| ) | |
| return spec_info | |
| PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = ( | |
| "Possible solutions:\n" | |
| "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" | |
| "2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n" | |
| "3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n" | |
| "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" | |
| ) | |
Xet Storage Details
- Size:
- 22 kB
- Xet hash:
- 5ff709b85ee149a96acf7ddf8fd25c7a6e3f04752b3b37de3bcbb68c0149bed4
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.