| from __future__ import annotations | |
| import bisect | |
| from typing import TYPE_CHECKING, Callable | |
| import torch | |
| from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len | |
| from sglang.srt.model_executor.cuda_graph_runner import ( | |
| CUDA_GRAPH_CAPTURE_FAILED_MSG, | |
| CudaGraphRunner, | |
| DeepEPCudaGraphRunnerAdapter, | |
| get_batch_sizes_to_capture, | |
| get_global_graph_memory_pool, | |
| model_capture_mode, | |
| set_global_graph_memory_pool, | |
| set_is_extend_in_batch, | |
| set_torch_compile_config, | |
| ) | |
| from sglang.srt.model_executor.forward_batch_info import ( | |
| CaptureHiddenMode, | |
| ForwardBatch, | |
| ForwardMode, | |
| ) | |
| from sglang.srt.speculative.eagle_info import EagleDraftInput | |
| from sglang.srt.utils import ( | |
| require_attn_tp_gather, | |
| require_gathered_buffer, | |
| require_mlp_sync, | |
| require_mlp_tp_gather, | |
| ) | |
| if TYPE_CHECKING: | |
| from sglang.srt.speculative.eagle_worker import EAGLEWorker | |
import logging

# Module-level logger named after this module (``__name__``), so its records
# flow through the standard ``logging`` hierarchy under that name.
logger = logging.getLogger(name=__name__)
class EAGLEDraftCudaGraphRunner:
    """Capture and replay CUDA graphs for the EAGLE draft model's forward pass.

    Graphs are captured through the shared ``CudaGraphRunner.capture`` loop
    for the batch sizes returned by ``get_batch_sizes_to_capture``. Each graph
    records ``eagle_worker.draft_forward`` on a ``DECODE``-mode
    ``ForwardBatch`` whose tensors are slices of persistent buffers owned by
    this runner. ``replay`` copies a live batch into those buffers. When
    needed, it pads up to the smallest captured batch size that fits. It then
    replays the matching graph and trims the outputs back to the real batch
    size.
    """

    def __init__(self, eagle_worker: EAGLEWorker):
        """Read configuration, allocate static graph inputs, and capture graphs.

        Args:
            eagle_worker: The EAGLE worker whose ``draft_forward`` is captured.
                When it has no ``model_runner`` attribute (the V2
                ``EagleDraftWorker``), its ``draft_runner`` is used instead.

        Raises:
            Exception: If capturing raises ``RuntimeError``. The message
                appends ``CUDA_GRAPH_CAPTURE_FAILED_MSG``, and the original
                error is chained as ``__cause__``.
        """
        # Parse args
        self.eagle_worker = eagle_worker
        if not hasattr(eagle_worker, "model_runner"):
            # V2: EagleDraftWorker
            self.model_runner = model_runner = eagle_worker.draft_runner
        else:
            self.model_runner = model_runner = eagle_worker.model_runner
        # Captured graphs and their output buffers, keyed by batch size.
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.dp_size = self.model_runner.dp_size
        self.tp_size = self.model_runner.tp_size
        self.topk = model_runner.server_args.speculative_eagle_topk
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
        server_args = model_runner.server_args

        # Batch sizes to capture. ``replay`` bisects this list, so it is
        # assumed to be sorted ascending (NOTE(review): confirm in
        # get_batch_sizes_to_capture).
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        # Token rows per sequence in the token-level graph inputs (EAGLE top-k).
        self.num_tokens_per_bs = server_args.speculative_eagle_topk

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.model_runner.draft_attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )
        # Value written into padded seq_lens entries, as provided by the first
        # draft attention backend.
        self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
            0
        ].get_cuda_graph_seq_len_fill_value()
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )
        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs

        if self.enable_torch_compile:
            set_torch_compile_config()

        # Graph inputs: persistent CUDA buffers sized for the largest captured
        # batch. Each captured graph reads a prefix of these tensors.
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
            # Sized max_num_token * speculative_num_steps; presumably one KV
            # slot per draft token per step (TODO confirm with draft_forward).
            self.out_cache_loc = torch.zeros(
                (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros(
                (3, self.max_num_token), dtype=torch.int64
            )
            self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
            self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
            self.hidden_states = torch.zeros(
                (self.max_bs, self.model_runner.model_config.hidden_size),
                dtype=self.model_runner.dtype,
            )

            if self.require_gathered_buffer:
                if self.require_mlp_tp_gather:
                    # One token count per DP rank.
                    self.global_num_tokens_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                else:
                    assert self.require_attn_tp_gather
                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (1,), dtype=torch.int32
                    )
            else:
                self.global_num_tokens_gpu = None
                self.global_num_tokens_for_logprob_gpu = None

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            # Chain the cause so the underlying CUDA/PyTorch traceback survives.
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            ) from e

    def _get_cuda_graph_bs(self, forward_batch: ForwardBatch) -> int:
        """Return the batch size used to select a captured graph.

        Without MLP TP gather this is simply ``forward_batch.batch_size``.
        With it, the size comes from the maximum of
        ``forward_batch.global_num_tokens_cpu``, presumably so that all ranks
        pick the same graph. For EAGLE that maximum is a token count and is
        divided by ``num_tokens_per_bs``.
        """
        if not self.require_mlp_tp_gather:
            return forward_batch.batch_size
        max_num_tokens = max(forward_batch.global_num_tokens_cpu)
        if self.model_runner.spec_algorithm.is_eagle():
            return max_num_tokens // self.num_tokens_per_bs
        return max_num_tokens

    def can_run(self, forward_batch: ForwardBatch):
        """Return whether ``forward_batch`` can be served by a captured graph.

        With padding disabled, the exact batch size must have been captured.
        Otherwise it only needs to be at most ``max_bs``. When MLP sync is
        required, ``forward_batch.can_run_dp_cuda_graph`` must also be truthy.
        """
        cuda_graph_bs = self._get_cuda_graph_bs(forward_batch)
        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )
        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
        return is_bs_supported

    def capture(self):
        """Capture the draft graphs by reusing the shared capture loop.

        ``CudaGraphRunner.capture`` is expected to call
        ``capture_one_batch_size`` for each size in ``capture_bs`` and store
        the results in ``self.graphs`` / ``self.output_buffers``
        (NOTE(review): confirm in cuda_graph_runner).
        """
        CudaGraphRunner.capture(self)

    def capture_one_batch_size(self, num_seqs: int, forward: Callable):
        """Capture ``eagle_worker.draft_forward`` for ``num_seqs`` sequences.

        Args:
            num_seqs: Batch size to capture. Token-level inputs hold
                ``num_seqs * num_tokens_per_bs`` rows.
            forward: Unused here. It is presumably kept so the signature
                matches what the shared capture loop passes.

        Returns:
            ``(graph, out)``, where ``out`` is the value returned by
            ``draft_forward`` inside the captured region.
        """
        graph = torch.cuda.CUDAGraph()
        # NOTE(review): ``self.stream`` is never assigned in this class;
        # presumably ``CudaGraphRunner.capture`` sets it before calling us.
        stream = self.stream
        num_tokens = num_seqs * self.num_tokens_per_bs

        # Graph inputs: prefixes of the persistent buffers. The tensor slices
        # are views, so later replays see whatever ``replay`` copies in.
        req_pool_indices = self.req_pool_indices[:num_seqs]
        seq_lens = self.seq_lens[:num_seqs]
        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
        extend_seq_lens = self.extend_seq_lens[:num_seqs]
        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
        out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
        positions = self.positions[:num_tokens]
        mrope_positions = self.mrope_positions[:, :num_tokens]
        topk_p = self.topk_p[:num_seqs]
        topk_index = self.topk_index[:num_seqs]
        hidden_states = self.hidden_states[:num_seqs]

        # Gathered-buffer metadata: every entry is set to this graph's token
        # count, the same values ``replay`` writes with ``fill_``.
        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.fill_(num_tokens)
            self.global_num_tokens_for_logprob_gpu.fill_(num_tokens)
            global_num_tokens = self.global_num_tokens_gpu
            global_dp_buffer_len = num_tokens * self.dp_size
            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.fill_(num_tokens)
            self.global_num_tokens_for_logprob_gpu.fill_(num_tokens)
            global_num_tokens = self.global_num_tokens_gpu
            global_dp_buffer_len = num_tokens
            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
        else:
            global_num_tokens = None
            global_dp_buffer_len = None
            global_num_tokens_for_logprob = None

        spec_info = EagleDraftInput(
            topk_p=topk_p,
            topk_index=topk_index,
            hidden_states=hidden_states,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

        # Forward batch
        forward_batch = ForwardBatch(
            forward_mode=ForwardMode.DECODE,
            batch_size=num_seqs,
            input_ids=None,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            extend_seq_lens=extend_seq_lens,
            extend_seq_lens_cpu=extend_seq_lens_cpu,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
            mrope_positions=mrope_positions,
            global_num_tokens_gpu=global_num_tokens,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=(
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            ),
            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
        )

        # Attention backend
        self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
            forward_batch
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
            set_is_extend_in_batch(False)

            # Back up two fields that `draft_forward` modifies in place, and
            # restore them afterwards so every run starts from the same inputs.
            output_cache_loc_backup = forward_batch.out_cache_loc
            hidden_states_backup = forward_batch.spec_info.hidden_states

            ret = self.eagle_worker.draft_forward(forward_batch)

            forward_batch.out_cache_loc = output_cache_loc_backup
            forward_batch.spec_info.hidden_states = hidden_states_backup
            return ret

        self.deepep_adapter.capture(is_extend_in_batch=False)

        # Two eager warm-up runs, each preceded by a device sync and TP barrier.
        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

        # Share the memory pool with subsequently captured graphs.
        set_global_graph_memory_pool(graph.pool())
        return graph, out

    def _postprocess_output_to_raw_bs(self, out, raw_bs):
        """Trim each of the three graph outputs to the first ``raw_bs`` rows."""
        # Keep the variables name for readability
        parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
        return parent_list, top_scores_index, draft_tokens

    def replay(self, forward_batch: ForwardBatch):
        """Run ``forward_batch`` through the matching captured graph.

        The batch is copied into the static buffers and, if needed, padded up
        to the smallest captured batch size that fits. Callers are expected
        to have checked ``can_run`` first; otherwise the ``capture_bs`` lookup
        can raise ``IndexError``.

        While padded, ``forward_batch.batch_size``, ``seq_lens``,
        ``req_pool_indices``, ``positions`` and ``seq_lens_cpu`` are
        temporarily pointed at the padded buffers. They are reset to the
        unpadded prefixes of those buffers before returning.
        NOTE(review): afterwards these fields alias this runner's buffers
        rather than the caller's original tensors.

        Returns:
            The graph outputs, trimmed to the real batch size when padding was
            applied.
        """
        assert forward_batch.out_cache_loc is not None
        self.deepep_adapter.replay()

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad: pick the smallest captured batch size that fits.
        index = bisect.bisect_left(
            self.capture_bs, self._get_cuda_graph_bs(forward_batch)
        )
        bs = self.capture_bs[index]
        if bs != raw_bs:
            # Reset these buffers so the padded tail holds neutral values.
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()
            self.positions.zero_()

        num_tokens = bs * self.num_tokens_per_bs

        # Common inputs
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
            forward_batch.out_cache_loc
        )
        self.positions[:raw_num_token].copy_(forward_batch.positions)
        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

        # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(num_tokens)
            self.global_num_tokens_for_logprob_gpu.fill_(num_tokens)

        # Attention backend
        if bs != raw_bs:
            forward_batch.batch_size = bs
            forward_batch.seq_lens = self.seq_lens[:bs]
            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
            forward_batch.positions = self.positions[:num_tokens]

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

        self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
            forward_batch, bs
        )
        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph

        # Replay
        self.graphs[bs].replay()
        out = self.output_buffers[bs]

        if bs != raw_bs:
            # Trim outputs and point the batch fields back at unpadded prefixes.
            out = self._postprocess_output_to_raw_bs(out, raw_bs)
            forward_batch.batch_size = raw_bs
            forward_batch.positions = self.positions[:raw_num_token]
            forward_batch.seq_lens = self.seq_lens[:raw_bs]
            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
            if forward_batch.seq_lens_cpu is not None:
                forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]

        return out
Xet Storage Details
- Size:
- 15 kB
- Xet hash:
- f37371996bd6f87d25941a2b98991c61e5db9a2be33e7e1baef2ec3776f49d18
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.