diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index aa10cb08d..d41c31a09 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -268,6 +268,12 @@ class ModelConfig: ): self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if ( + is_draft_model + and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM" + ): + self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 51af67636..3ec1778ed 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server from __future__ import annotations import logging +import os import time from collections import deque from dataclasses import dataclass @@ -315,6 +316,16 @@ class DecodePreallocQueue: ) return kv_manager + def release_memory_occupation(self): + self.queue.clear() + self.retracted_queue.clear() + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): + self.kv_manager.deregister_buffer_to_engine() + + def resume_memory_occupation(self): + if hasattr(self.kv_manager, "register_buffer_to_engine"): + self.kv_manager.register_buffer_to_engine() + def add(self, req: Req, is_retracted: bool = False) -> None: """Add a request to the pending queue.""" if self._check_if_req_exceed_kv_capacity(req): @@ -419,12 +430,37 @@ class DecodePreallocQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) + # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. 
+ bootstrap_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) + now = time.perf_counter() + for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if rids_to_check is not None and decode_req.req.rid not in rids_to_check: continue if poll == KVPoll.Bootstrapping: - pass + # Check for bootstrap timeout + entry_time = getattr( + decode_req.req.time_stats, + "decode_prealloc_queue_entry_time", + None, + ) + if entry_time is not None and (now - entry_time) > bootstrap_timeout: + error_message = ( + f"Decode bootstrap timed out after {now - entry_time:.1f}s " + f"for request rank={self.tp_rank} " + f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" + ) + logger.error(error_message) + prepare_abort( + decode_req.req, + error_message, + status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) + if self.scheduler.enable_metrics: + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() elif poll == KVPoll.WaitingForInput: decode_req.waiting_for_input = True elif poll == KVPoll.Failed: @@ -776,6 +812,13 @@ class DecodeTransferQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) + # Transfer timeout: if a request has been in the transfer queue for too long + # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed. 
+ transfer_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) + now = time.perf_counter() + transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): @@ -811,7 +854,33 @@ class DecodeTransferQueue: KVPoll.WaitingForInput, KVPoll.Transferring, ]: - pass + # Check for transfer timeout + entry_time = getattr( + decode_req.req.time_stats, + "decode_transfer_queue_entry_time", + None, + ) + if entry_time is not None and (now - entry_time) > transfer_timeout: + error_message = ( + f"Decode transfer timed out after {now - entry_time:.1f}s " + f"(state={poll}) for request rank={self.tp_rank} " + f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" + ) + logger.error(error_message) + prepare_abort( + decode_req.req, + error_message, + status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) + release_kv_cache( + decode_req.req, self.tree_cache, is_insert=False + ) + indices_to_remove.add(i) + if self.scheduler.enable_metrics: + self.scheduler.metrics_collector.increment_transfer_failed_reqs() else: raise ValueError(f"Unexpected poll case: {poll}") @@ -827,6 +896,14 @@ class DecodeTransferQueue: return transferred_reqs + def release_memory_occupation(self): + """Clean up all in-flight transfers before releasing GPU memory.""" + self.queue.clear() + + def resume_memory_occupation(self): + """Resume after GPU memory re-allocation. Queue was already cleared on release.""" + pass + class SchedulerDisaggregationDecodeMixin: @@ -1004,7 +1081,15 @@ class SchedulerDisaggregationDecodeMixin: resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() self.waiting_queue.extend(resumed_reqs) if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: - # if there are still retracted requests, we do not allocate new requests + # Still have retracted requests that couldn't resume (not enough memory). 
+ # Don't accept new requests (pop_preallocated) — they would consume memory + # that retracted requests need. + # But DO drain completed transfers: their KV is already committed, and + # moving them to waiting_queue frees the reserved-decode-token budget + # in _allocatable_tokens(), which may unblock resume on the next iteration. + # Without this, completed transfers hold memory indefinitely → deadlock. + alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred() + self.waiting_queue.extend(alloc_reqs) return if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 32e8c0b69..dc93c5c5f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -253,6 +253,19 @@ class MooncakeKVManager(CommonKVManager): self.kv_args.state_data_ptrs, self.kv_args.state_data_lens ) + def deregister_buffer_to_engine(self): + # Batch deregister KV data buffers + if self.kv_args.kv_data_ptrs: + self.engine.batch_deregister(self.kv_args.kv_data_ptrs) + + # Batch deregister auxiliary data buffers + if self.kv_args.aux_data_ptrs: + self.engine.batch_deregister(self.kv_args.aux_data_ptrs) + + # Batch deregister state/extra pool data buffers + if self.kv_args.state_data_ptrs: + self.engine.batch_deregister(self.kv_args.state_data_ptrs) + def _transfer_data(self, mooncake_session_id, transfer_blocks): if not transfer_blocks: return 0 diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index a6eed743a..191b0977f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server from __future__ import annotations import logging +import os import time from collections import deque from http import HTTPStatus @@ -250,6 +251,12 @@ class PrefillBootstrapQueue: 
[req.disagg_kv_sender for req in self.queue], self.gloo_group ) + # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. + bootstrap_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) + now = time.perf_counter() + for i, (req, poll) in enumerate(zip(self.queue, polls)): if rids_to_check is not None: # if req not in reqs_info_to_check, skip @@ -257,6 +264,27 @@ class PrefillBootstrapQueue: continue if poll == KVPoll.Bootstrapping: + # Check for bootstrap timeout + entry_time = getattr( + req.time_stats, + "prefill_bootstrap_queue_entry_time", + None, + ) + if entry_time is not None and (now - entry_time) > bootstrap_timeout: + error_message = ( + f"Prefill bootstrap timed out after {now - entry_time:.1f}s " + f"for request rank={self.tp_rank} " + f"{req.rid=} {req.bootstrap_room=}" + ) + logger.error(error_message) + prepare_abort( + req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT + ) + self.scheduler.stream_output([req], req.return_logprob) + indices_to_remove.add(i) + failed_reqs.append(req) + if self.scheduler.enable_metrics: + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() continue elif poll == KVPoll.Failed: error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" @@ -306,6 +334,15 @@ class PrefillBootstrapQueue: else: return bootstrapped_reqs, failed_reqs + def release_memory_occupation(self): + self.queue.clear() + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): + self.kv_manager.deregister_buffer_to_engine() + + def resume_memory_occupation(self): + if hasattr(self.kv_manager, "register_buffer_to_engine"): + self.kv_manager.register_buffer_to_engine() + class SchedulerDisaggregationPrefillMixin: """ @@ -535,6 +572,13 @@ class SchedulerDisaggregationPrefillMixin: self.attn_tp_cpu_group, ) + # Transfer timeout: if a request has been in the inflight queue for too long + # (e.g., stuck in 
WaitingForInput/Transferring), treat it as failed. + transfer_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) + now = time.perf_counter() + undone_reqs: List[Req] = [] # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue for req, poll in zip(self.disagg_prefill_inflight_queue, polls): @@ -547,7 +591,30 @@ class SchedulerDisaggregationPrefillMixin: assert poll == KVPoll.Success or poll == KVPoll.Failed if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: - undone_reqs.append(req) + # Check for transfer timeout + entry_time = getattr( + req.time_stats, + "prefill_transfer_queue_entry_time", + None, + ) + if entry_time is not None and (now - entry_time) > transfer_timeout: + error_message = ( + f"Prefill transfer timed out after {now - entry_time:.1f}s " + f"(state={poll}) for request rank={self.tp_rank} " + f"{req.rid=} {req.bootstrap_room=}" + ) + logger.error(error_message) + release_kv_cache(req, self.tree_cache) # unlock the tree + prepare_abort( + req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT + ) + if hasattr(req.disagg_kv_sender, "clear"): + req.disagg_kv_sender.clear() + done_reqs.append(req) + if self.enable_metrics: + self.metrics_collector.increment_transfer_failed_reqs() + else: + undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 0478526ef..cfb1aa669 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" - return get_tp_group().rank_in_group + try: + return 
get_tp_group().rank_in_group + except Exception: + return 0 def get_pipeline_model_parallel_world_size(): diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6f69fd19b..da20ac2ed 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, MultimodalDataInputFormat, + PostProcessWeightsReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, @@ -593,6 +594,24 @@ class Engine(EngineBase): self.tokenizer_manager.update_weights_from_ipc(obj, None) ) + def post_process_weights( + self, + restore_weights_before_load: bool = False, + post_process_quantization: bool = False, + ): + """ + Optional post-processing for updated weights (e.g., Marlin conversion). + Should be called after weight update is finished. + """ + obj = PostProcessWeightsReqInput( + restore_weights_before_load=restore_weights_before_load, + post_process_quantization=post_process_quantization, + ) + + return self.loop.run_until_complete( + self.tokenizer_manager.post_process_weights(obj, None) + ) + def get_weights_by_name(self, name: str, truncate_size: int = 100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 88705cc35..c8dc052f1 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, ParseFunctionCallReq, PauseGenerationReqInput, + PostProcessWeightsReqInput, ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re else: return 
ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +@app.post("/post_process_weights") +async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): + """ + Optional post-processing for updated weights (e.g., Marlin conversion). + This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. + """ + success, message = await _global_state.tokenizer_manager.post_process_weights( + req, request + ) + + content = {"success": success, "message": message} + return ORJSONResponse( + content, status_code=200 if success else HTTPStatus.BAD_REQUEST + ) + @app.post("/update_weight_version") async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py index d6c499df0..565004260 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -613,7 +613,6 @@ def _get_k_and_s_triton( page_indices, k_out, s_out, - seq_len, page_size, buf_numel_per_page, index_head_dim, @@ -630,7 +629,6 @@ def _get_k_and_s_triton_kernel( page_indices_ptr, k_out_ptr, s_out_ptr, - seq_len: tl.constexpr, page_size: tl.constexpr, buf_numel_per_page: tl.constexpr, index_head_dim: tl.constexpr, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index c9e82e4b1..f2584546a 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import os import torch from einops import rearrange @@ -178,7 +179,7 @@ class Indexer(MultiPlatformOp): max_position=max_position_embeddings, base=rope_theta, # type: ignore 
rope_scaling=rope_scaling, - is_neox_style=True, + is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False, device=get_global_server_args().device, ) self.block_size = block_size @@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp): @torch.compile(dynamic=True) def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): weights, _ = self.weights_proj(x.float()) + if weights.shape[1] < 32: + assert 32 % weights.shape[1] == 0 + weights = weights.repeat_interleave(32 // weights.shape[1], dim=1) weights = weights * self.n_heads**-0.5 weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale return weights @@ -837,6 +841,9 @@ class Indexer(MultiPlatformOp): query, key = self._get_q_k_bf16( q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch ) + if query.shape[1] < 32: + assert 32 % query.shape[1] == 0 + query = query.repeat_interleave(32//query.shape[1], dim=1) if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 15df851eb..1636ed706 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -371,10 +371,13 @@ class LayerCommunicator: residual: torch.Tensor, forward_batch: ForwardBatch, captured_last_layer_outputs: Optional[List[torch.Tensor]] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ): hidden_states, residual = self.prepare_attn( - hidden_states, residual, forward_batch, **kwargs + hidden_states, + residual, + forward_batch, + post_residual_addition=post_residual_addition, ) if captured_last_layer_outputs is not None: gathered_last_layer_output = self._communicate_simple_fn( @@ -394,7 +397,7 @@ class LayerCommunicator: residual: torch.Tensor, forward_batch: ForwardBatch, quant_format: str = "", - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ): if get_attn_tp_context().input_scattered: 
hidden_states, residual = self._tp_reduce_scatter( @@ -444,7 +447,7 @@ class LayerCommunicator: ) else: - hidden_states = self.input_layernorm(hidden_states, **kwargs) + hidden_states = self.input_layernorm(hidden_states) else: if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): @@ -478,7 +481,7 @@ class LayerCommunicator: hidden_states, residual = self.input_layernorm( hidden_states, residual, - **kwargs, + post_residual_addition, ) hidden_states = self._communicate_simple_fn( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7bef9d2ab..5926ff7f5 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): eps: float = 1e-6, var_hidden_size: Optional[int] = None, cast_x_before_out_mul: bool = False, - fp32_residual: bool = False, - weight_dtype: Optional = None, - override_orig_dtype: Optional = None, + fp32_residual: bool = True, ) -> None: super().__init__() self.cast_x_before_out_mul = cast_x_before_out_mul self.fp32_residual = fp32_residual - self.override_orig_dtype = override_orig_dtype - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) + self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( @@ -104,16 +101,16 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: - return self.forward_native(x, residual, **kwargs) + return self.forward_native(x, residual, post_residual_addition) if is_batch_invariant_mode_enabled(): if ( residual is not None or get_global_server_args().rl_on_policy_target == "fsdp" ): - return self.forward_native(x, residual, **kwargs) + return self.forward_native(x, 
residual, post_residual_addition) return rms_norm_batch_invariant( x, self.weight.data, @@ -124,7 +121,6 @@ class RMSNorm(MultiPlatformOp): # but right now we can only have hidden_states+(residual+post_residual_addition). # (hidden_states+residual)+post_residual_addition != hidden_states+(residual+post_residual_addition), # we probably need to add another parameter to fused_add_rmsnorm - post_residual_addition = kwargs.get("post_residual_addition") if post_residual_addition is not None: residual = residual + post_residual_addition fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) @@ -136,9 +132,11 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition out, _, residual_out = torch_npu.npu_add_rms_norm( residual, x, self.weight.data, self.variance_epsilon ) @@ -149,9 +147,11 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition residual_out = torch.empty_like(x) output = torch.empty_like(x) fused_add_rms_norm( @@ -169,12 +169,14 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not x.is_contiguous(): # NOTE: Remove this if aiter kernel supports discontinuous input x = x.contiguous() if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition out = 
torch.empty_like(x) residual_out = torch.empty_like(x) fused_add_rms_norm( @@ -189,27 +191,23 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not x.is_contiguous(): x = x.contiguous() - orig_dtype = self.override_orig_dtype or x.dtype - post_residual_addition = kwargs.get("post_residual_addition") + orig_dtype = x.dtype + + if residual is not None and not self.fp32_residual: + x = x + residual + if post_residual_addition is not None: + x = x + post_residual_addition + residual = x.clone() x = x.to(torch.float32) - if residual is not None: - x = ( - x - + residual.to(torch.float32) - + ( - post_residual_addition.to(torch.float32) - if post_residual_addition is not None - else 0.0 - ) - ) - if self.fp32_residual: - residual = x.clone() - else: - residual = x.to(orig_dtype) + if residual is not None and self.fp32_residual: + x = x + residual.to(torch.float32) + if post_residual_addition is not None: + x = x + post_residual_addition.to(torch.float32) + residual = x.to(orig_dtype) hidden_size = x.shape[-1] if hidden_size != self.hidden_size: @@ -246,10 +244,12 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if _is_cpu_amx_available: if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( x, residual, self.weight.data, self.variance_epsilon ) @@ -258,17 +258,19 @@ class RMSNorm(MultiPlatformOp): x, self.weight.data, self.variance_epsilon ) else: - return self.forward_native(x, residual, **kwargs) + return self.forward_native(x, residual, post_residual_addition) def forward_xpu( self, x: torch.Tensor, residual: 
Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: - return self.forward_native(x, residual, **kwargs) + return self.forward_native(x, residual, post_residual_addition) if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) @@ -278,6 +280,7 @@ class RMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Forward method with allreduce fusion, prioritizing flashinfer fused operations @@ -289,6 +292,8 @@ class RMSNorm(MultiPlatformOp): ) if get_tensor_model_parallel_world_size() > 1: + if post_residual_addition is not None: + x = x + post_residual_addition fused_result = flashinfer_allreduce_residual_rmsnorm( input_tensor=x, residual=residual, @@ -298,7 +303,7 @@ class RMSNorm(MultiPlatformOp): if fused_result[0] is not None: return fused_result - return self.forward(x, residual) + return self.forward(x, residual, post_residual_addition) class LayerNorm(MultiPlatformOp): @@ -323,7 +328,6 @@ class LayerNorm(MultiPlatformOp): def forward_cuda( self, x: torch.Tensor, - **kwargs, ) -> torch.Tensor: if ( _flashinfer_layernorm_available @@ -332,12 +336,11 @@ class LayerNorm(MultiPlatformOp): ): return layernorm(x, self.weight, self.bias, self.variance_epsilon) else: - return self.forward_native(x, **kwargs) + return self.forward_native(x) def forward_native( self, x: torch.Tensor, - **kwargs, ) -> torch.Tensor: weight = self.weight if self.elementwise_affine else None bias = self.bias if self.use_bias else None @@ -354,28 +357,25 @@ class LayerNorm(MultiPlatformOp): 
def forward_hip( self, x: torch.Tensor, - **kwargs, ) -> torch.Tensor: - return self.forward_native(x, **kwargs) + return self.forward_native(x) def forward_npu( self, x: torch.Tensor, - **kwargs, ) -> torch.Tensor: - return self.forward_native(x, **kwargs) + return self.forward_native(x) def forward_cpu( self, x: torch.Tensor, - **kwargs, ) -> torch.Tensor: if _is_cpu_amx_available: return torch.ops.sgl_kernel.layernorm_cpu( x, self.weight.data, self.variance_epsilon ) else: - return self.forward_native(x, **kwargs) + return self.forward_native(x) class GemmaRMSNorm(MultiPlatformOp): @@ -396,9 +396,11 @@ class GemmaRMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition gemma_fused_add_rmsnorm( x, residual, self.weight.data, self.variance_epsilon ) @@ -410,11 +412,13 @@ class GemmaRMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype if residual is not None: x = x + residual + if post_residual_addition is not None: + x = x + post_residual_addition residual = x x = x.float() @@ -428,18 +432,20 @@ class GemmaRMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - return self._forward_impl(x, residual, **kwargs) + return self._forward_impl(x, residual, post_residual_addition) def forward_cpu( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor]]: if _is_cpu_amx_available: if residual is not None: + if post_residual_addition is not None: + residual = residual + post_residual_addition torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( x, residual, self.weight.data, self.variance_epsilon ) @@ -447,16 +453,18 @@ class GemmaRMSNorm(MultiPlatformOp): return torch.ops.sgl_kernel.gemma_rmsnorm_cpu( x, self.weight.data, self.variance_epsilon ) - return self.forward_native(x, residual, **kwargs) + return self.forward_native(x, residual, post_residual_addition) def forward_npu( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: x = x + residual + if post_residual_addition is not None: + x = x + post_residual_addition residual = x x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) @@ -466,9 +474,9 @@ class GemmaRMSNorm(MultiPlatformOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - return self._forward_impl(x, residual, **kwargs) + return self._forward_impl(x, residual, post_residual_addition) class Gemma3RMSNorm(MultiPlatformOp): @@ -481,22 +489,22 @@ class Gemma3RMSNorm(MultiPlatformOp): def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward_native(self, x, **kwargs): + def forward_native(self, x): output = self._norm(x.float()) # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 output = output * (1.0 + self.weight.float()) return output.type_as(x) - def forward_cpu(self, x, **kwargs): + def forward_cpu(self, x): if _is_cpu_amx_available and x.stride(-1) == 1: return torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, self.weight, 
self.eps) - return self.forward_native(x, **kwargs) + return self.forward_native(x) - def forward_cuda(self, x, **kwargs): - return self.forward_native(x, **kwargs) + def forward_cuda(self, x): + return self.forward_native(x) - def forward_npu(self, x, **kwargs): + def forward_npu(self, x): output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps) return output diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index fa7431048..cd33ea735 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): None, # bias True, # is_vnni ) - elif get_global_server_args().rl_on_policy_target is not None: - # Due to tie-weight, we may not be able to change lm_head's weight dtype - logits = torch.matmul( - hidden_states.bfloat16(), lm_head.weight.T.bfloat16() - ) else: logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index a1885fade..14d692365 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -14,6 +14,7 @@ import torch.nn.functional as F import triton.language as tl from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -573,7 +574,10 @@ def fused_experts_impl( ).squeeze(dim=1) else: # According to micro benchmark results, torch.compile can get better performance for small token. 
- if tokens_in_chunk <= 32: + if ( + not get_global_server_args().enable_deterministic_inference + and tokens_in_chunk <= 32 + ): moe_sum_reduce_torch_compile( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 839463518..7948779aa 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module): "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", ] - ) + ) and "zero" not in weight_name else loaded_weight ) diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py index 00bd68755..5a3ca8a67 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -1,5 +1,6 @@ import logging from abc import ABC +from contextlib import contextmanager from typing import Optional import numpy as np @@ -8,13 +9,18 @@ import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, get_attention_dp_rank, + get_attention_tp_size, get_dp_local_info, is_dp_attention_enabled, ) from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args +from sglang.srt.layers.moe import ( + get_moe_a2a_backend, +) logger = logging.getLogger(__name__) @@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): device=device, ) + if get_moe_a2a_backend().is_deepep(): + attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 + self.gather_buffer = torch.empty( + ( + self.device_cache.buffer.shape[0] * 
attn_tp_size, + self.device_cache.buffer.shape[2], + ), + dtype=torch.int32, + device=device, + ) + def _sync_fwd_experts_buffer_DtoH( self, forward_batch: ForwardBatch, can_run_graph: bool, cuda_graph_batch: int, ): - if is_dp_attention_enabled(): + # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer + # contains data from all DP ranks. We should not slice by DP rank in this case. + if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) # handle with cuda graph padding if can_run_graph: @@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): ].cpu() def capture(self, layer_id: int, topk_ids: torch.Tensor): + if get_moe_a2a_backend().is_deepep(): + local_topk_ids = topk_ids + topk_ids = self.gather_buffer[ + : local_topk_ids.size(0) * get_attention_tp_size() + ] + attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) def get_routed_experts( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index b4bdc41b3..3b895ff6a 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig): ) is_static = not weight_quant.dynamic - return is_channel_group and input_quant_none and is_symmetric and is_static + return is_channel_group and input_quant_none and is_static def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c5e5a11fc..c46526ecc 100644 --- 
a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import ( normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack -from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales +from sglang.srt.layers.quantization.marlin_utils import ( + marlin_moe_permute_scales, + moe_awq_to_marlin_zero_points +) from sglang.srt.layers.quantization.utils import ( all_close_1d, per_tensor_dequantize, @@ -865,7 +868,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder - assert config.symmetric, "Only symmetric quantization is supported for MoE" + self.sym = config.symmetric if not ( self.quant_config.quant_format == CompressionFormat.pack_quantized.value @@ -920,7 +923,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): # In the case where we have actorder/g_idx, # we do not partition the w2 scales - load_full_w2 = self.actorder and self.group_size != -1 + load_full_w2 = (self.actorder != 'static') and self.group_size != -1 if load_full_w2: w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size @@ -968,6 +971,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) + # add zero param + if not self.sym: + w13_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_zero_point", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + 
torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_zero_point", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, @@ -1016,13 +1045,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.a2_scale = None layer.marlin_state = GPTQMarlinState.REPACK + if not hasattr(layer, "_original_shapes"): + layer._original_shapes = {} + + # Force record: these are the target GPTQ shapes for rollback. + layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) + layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) + if not self.sym: + layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape + + layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) + layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) + if not self.sym: + layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Skip if the layer is already converted to Marlin format to prevent double-packing. + if getattr(layer, "is_marlin_converted", False): + return + + if not hasattr(layer, "_original_shapes"): + layer._original_shapes = {} def replace_tensor(name, new_t): + target_attr = getattr(layer, name) + + # Only save if the key doesn't exist to prevent overwriting with Marlin shapes. + if name not in layer._original_shapes: + # This is a safety check; `create_weights` usually handles this already. 
+ layer._original_shapes[name] = tuple(target_attr.shape) + # It is important to use resize_() here since it ensures # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) + target_attr.resize_(new_t.shape) + target_attr.copy_(new_t) del new_t num_experts = layer.w13_weight_g_idx.shape[0] @@ -1078,7 +1134,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_packed.shape[2], self.num_bits, ) - replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) + replace_tensor("w13_weight_packed", marlin_w13_qweight) marlin_w2_qweight = gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, @@ -1086,7 +1142,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_packed.shape[2], self.num_bits, ) - replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) + replace_tensor("w2_weight_packed", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( layer.w13_weight_scale, @@ -1094,7 +1150,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_scale.shape[2], self.group_size, ) - replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) + replace_tensor("w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( layer.w2_weight_scale, @@ -1103,7 +1159,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_scale.shape[2], self.group_size, ) - replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + # Repack zero + if not self.sym: + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_weight_zero_point, + size_k=layer.w13_weight_zero_point.shape[1], + size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor, + num_bits=self.num_bits, + ) + replace_tensor("w13_weight_zero_point", marlin_w13_zp) + + marlin_w2_zp = 
moe_awq_to_marlin_zero_points( + layer.w2_weight_zero_point, + size_k=layer.w2_weight_zero_point.shape[1], + size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor, + num_bits=self.num_bits, + ) + replace_tensor("w2_weight_zero_point", marlin_w2_zp) + + layer.is_marlin_converted = True + + def restore_weights_before_loading(self, layer: torch.nn.Module): + """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" + if not hasattr(layer, "_original_shapes"): + return + + for name, orig_shape in layer._original_shapes.items(): + param = getattr(layer, name, None) + + if param is not None and param.shape != orig_shape: + param.resize_(orig_shape) + + layer.is_marlin_converted = False def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig @@ -1154,6 +1243,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, + w1_zeros=layer.w13_weight_zero_point if not self.sym else None, + w2_zeros=layer.w2_weight_zero_point if not self.sym else None, num_bits=self.num_bits, is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 480579e01..dd8ca7d4f 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp): if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native - self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( - self._apply_rotary_emb_wrapped - ) + self.position_cos, self.position_sin = None, None def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding): key: 
torch.Tensor, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + assert ( + fused_set_kv_buffer_arg is None + ), "fused_set_kv_buffer_arg is not supported for npu implementation" # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 assert ( fused_set_kv_buffer_arg is None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 55bef5652..35ad68b1c 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -108,16 +108,11 @@ class Sampler(nn.Module): if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: probs_without_temp_scaling = torch.softmax(logits, dim=-1) - if get_global_server_args().rl_on_policy_target is not None: - logits_div_temperature = ( - logits.bfloat16().div(sampling_info.temperatures).bfloat16() - ) - logprobs_via_logsoftmax_kernel = torch.log_softmax( - logits_div_temperature, dim=-1 - ) - # Post process logits logits.div_(sampling_info.temperatures) + if get_global_server_args().rl_on_policy_target is not None: + logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) + # For ascend backend, softmax is not needed before sampling if not get_global_server_args().sampling_backend == "ascend" or ( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2ecd8542f..2a2e011ea 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): success: bool message: str +@dataclass +class PostProcessWeightsReqInput(BaseReq): + # Whether to restore weights before loading new weights + restore_weights_before_load: bool = False + # Whether to enable quantization post-processing + post_process_quantization: bool = False + + +@dataclass +class PostProcessWeightsReqOutput(BaseReq): + success: bool + message: 
str + @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): @@ -1667,6 +1680,10 @@ class GetLoadReqOutput(BaseReq): num_waiting_reqs: int num_tokens: int ts_tic: float + # Per-queue breakdown: list of {name, num_reqs, num_tokens, reqs: [{seqlen}]} + queue_details: Optional[List[Dict[str, Any]]] = None + # Running batch info + running_details: Optional[Dict[str, Any]] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d423e61d7..d1f54a832 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1779,7 +1779,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): selected_indices=sorted_indices, buf_multiplier=buf_multiplier ) ): - if len(sorted_indices) == 1: + # We should allow all requests to be retracted in decode disaggregation mode + # because there can be prealloc prefill requests. + num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1 + if len(sorted_indices) == num_minimum_reqs: # Always keep at least one request break diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 92d286897..43bfab691 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import ( OpenSessionReqInput, OpenSessionReqOutput, PauseGenerationReqInput, + PostProcessWeightsReqInput, ProfileReq, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -1060,6 +1061,7 @@ class Scheduler( ), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), + (PostProcessWeightsReqInput, self.post_process_weights), (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, 
self.resume_memory_occupation), diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index d44ff6027..3fad54598 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -553,12 +553,48 @@ class SchedulerMetricsMixin: num_tokens += sum(req.seqlen for queue in waiting_queues for req in queue) num_waiting_reqs = sum(len(queue) for queue in waiting_queues) + # Collect per-queue details + queue_names = ["waiting_queue"] + if self.disaggregation_mode == DisaggregationMode.PREFILL: + queue_names.append("bootstrap_queue") + elif self.disaggregation_mode == DisaggregationMode.DECODE: + queue_names.append("prealloc_queue") + queue_names.append("transfer_queue") + queue_names.append("retracted_queue") + + queue_details = [] + for name, queue in zip(queue_names, waiting_queues): + reqs_info = [] + for req in queue: + reqs_info.append({ + "seqlen": req.seqlen, + }) + queue_details.append({ + "name": name, + "num_reqs": len(queue), + "num_tokens": sum(r["seqlen"] for r in reqs_info), + "reqs": reqs_info, + }) + + # Collect running batch details + running_reqs_info = [] + for req in self.running_batch.reqs: + running_reqs_info.append({ + "seqlen": req.seqlen, + }) + running_details = { + "num_reqs": len(self.running_batch.reqs), + "reqs": running_reqs_info, + } + return GetLoadReqOutput( dp_rank=self.dp_rank, num_reqs=len(self.running_batch.reqs) + num_waiting_reqs, num_waiting_reqs=num_waiting_reqs, num_tokens=num_tokens, ts_tic=time.perf_counter(), + queue_details=queue_details, + running_details=running_details, ) @contextmanager diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e40586c24..243e2b0c2 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py 
@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOutput, @@ -1070,7 +1071,7 @@ class SchedulerOutputProcessorMixin: req.log_time_stats() # Send to detokenizer - if reqs or is_idle_batch: + if rids or is_idle_batch: if self.model_config.is_multimodal_gen: return diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 293a84350..8ee36c794 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import traceback from typing import TYPE_CHECKING, Tuple @@ -12,6 +13,9 @@ from sglang.srt.constants import ( GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, ) +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.managers.io_struct import ( CheckWeightsReqInput, CheckWeightsReqOutput, @@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import ( GetWeightsByNameReqOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, + PostProcessWeightsReqInput, + PostProcessWeightsReqOutput, ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, @@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin: torch.distributed.barrier(group=self.tp_cpu_group) return UpdateWeightsFromIPCReqOutput(success, message) + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + """Optional post-processing for updated 
weights (e.g., Marlin conversion).""" + success, message = self.tp_worker.post_process_weights(recv_req) + return PostProcessWeightsReqOutput(success, message) + def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) @@ -137,6 +148,15 @@ class SchedulerUpdateWeightsMixin: self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) self.flush_cache() + if self.disaggregation_mode == DisaggregationMode.DECODE: + if hasattr(self, "disagg_decode_transfer_queue"): + self.disagg_decode_transfer_queue.release_memory_occupation() + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.release_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + if hasattr(self, "disagg_prefill_bootstrap_queue"): + self.disagg_prefill_bootstrap_queue.release_memory_occupation() + if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = _export_static_state( self.tp_worker.model_runner.model @@ -177,6 +197,15 @@ class SchedulerUpdateWeightsMixin: if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + if self.disaggregation_mode == DisaggregationMode.DECODE: + if hasattr(self, "disagg_decode_transfer_queue"): + self.disagg_decode_transfer_queue.resume_memory_occupation() + if hasattr(self, "disagg_decode_prealloc_queue"): + self.disagg_decode_prealloc_queue.resume_memory_occupation() + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + if hasattr(self, "disagg_prefill_bootstrap_queue"): + self.disagg_prefill_bootstrap_queue.resume_memory_occupation() + return ResumeMemoryOccupationReqOutput() def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index e5d42bed8..412293b30 100644 --- 
a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import ( LoadLoRAAdapterReqOutput, LoRAUpdateOutput, OpenSessionReqInput, + PostProcessWeightsReqInput, + PostProcessWeightsReqOutput, ProfileReq, ProfileReqOutput, ProfileReqType, @@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin: self.update_weights_from_ipc_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.post_process_weights_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin: UpdateWeightsFromIPCReqOutput, self.update_weights_from_ipc_communicator.handle_recv, ), + ( + PostProcessWeightsReqOutput, + self.post_process_weights_communicator.handle_recv, + ), ( GetWeightsByNameReqOutput, self.get_weights_by_name_communicator.handle_recv, @@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin: return success, message + async def post_process_weights( + self: TokenizerManager, + obj: PostProcessWeightsReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" + self.auto_create_handle_loop() + async with self.model_update_lock.writer_lock: + results = await self.post_process_weights_communicator(obj) + return _Communicator.merge_results(results) + async def init_weights_send_group_for_remote_instance( self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 49f63a198..e4cd0ff2b 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( 
InitWeightsSendGroupForRemoteInstanceReqInput, InitWeightsUpdateGroupReqInput, LoadLoRAAdapterReqInput, + PostProcessWeightsReqInput, SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, @@ -175,6 +176,11 @@ class BaseTpWorker(ABC): success, message = self.model_runner.update_weights_from_ipc(recv_req) return success, message + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" + success, message = self.model_runner.post_process_weights(recv_req) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index eaf29628b..bf74cbd12 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -287,6 +287,85 @@ def alloc_decode_kernel( tl.store(out_indices + pid, page * page_size) +def alloc_extend_torch_fallback( + prefix_lens_cpu: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + free_pages: torch.Tensor, + out_indices: torch.Tensor, + page_size: int, + debug_mode: bool = False, +): + extend_lens_cpu = (seq_lens_cpu - prefix_lens_cpu).to(torch.int64) + if extend_lens_cpu.numel() == 0: + return + + output_start_locs_cpu = torch.cumsum(extend_lens_cpu, dim=0) - extend_lens_cpu + num_pages_after = (seq_lens_cpu + page_size - 1) // page_size + num_pages_before = (prefix_lens_cpu + page_size - 1) // page_size + num_new_pages_cpu = num_pages_after - num_pages_before + page_start_locs_cpu = torch.cumsum(num_new_pages_cpu, dim=0) - num_new_pages_cpu + + total_new_pages = int(num_new_pages_cpu.sum().item()) + if total_new_pages > free_pages.numel(): + return + + if debug_mode: + assert int(extend_lens_cpu.sum().item()) == 
out_indices.numel() + + prefix_lens_list = prefix_lens_cpu.tolist() + seq_lens_list = seq_lens_cpu.tolist() + extend_lens_list = extend_lens_cpu.tolist() + out_start_list = output_start_locs_cpu.tolist() + page_start_list = page_start_locs_cpu.tolist() + num_new_pages_list = num_new_pages_cpu.tolist() + + device = out_indices.device + dtype = out_indices.dtype + offsets_page = torch.arange(page_size, device=device, dtype=dtype) + + for i, extend_len in enumerate(extend_lens_list): + if extend_len == 0: + continue + + pre_len = prefix_lens_list[i] + seq_len = seq_lens_list[i] + out_start = out_start_list[i] + page_start = page_start_list[i] + num_new_pages = num_new_pages_list[i] + + pre_mod = pre_len % page_size + part1 = min(extend_len, page_size - pre_mod) if pre_mod != 0 else 0 + if part1: + start_val = last_loc[i] + 1 + out_indices[out_start : out_start + part1] = start_val + torch.arange( + part1, device=device, dtype=dtype + ) + if part1 == extend_len: + continue + + ceil_pre_pages = (pre_len + page_size - 1) // page_size + full_pages_after = seq_len // page_size + num_full_pages = full_pages_after - ceil_pre_pages + if num_full_pages < 0: + num_full_pages = 0 + part2 = num_full_pages * page_size + if part2: + pages = free_pages[page_start : page_start + num_full_pages] + full_indices = (pages[:, None] * page_size + offsets_page).reshape(-1) + out_indices[out_start + part1 : out_start + part1 + part2] = full_indices + if part1 + part2 == extend_len: + continue + + part3 = extend_len - part1 - part2 + if part3: + last_page = free_pages[page_start + num_new_pages - 1] + out_indices[out_start + part1 + part2 : out_start + extend_len] = ( + last_page * page_size + + torch.arange(part3, device=device, dtype=dtype) + ) + + class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): """ An allocator managing the indices to kv cache data. 
@@ -349,11 +428,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) - self.seen_max_num_extend_tokens_next_power_of_2 = max( - self.seen_max_num_extend_tokens_next_power_of_2, - next_power_of_2(extend_num_tokens), - ) - bs = len(prefix_lens) if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len( self.free_pages @@ -363,16 +437,34 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int64, device=self.device ) - alloc_extend_kernel[(bs,)]( - prefix_lens, - seq_lens, - last_loc, - self.free_pages, - out_indices, - next_power_of_2(bs), - self.page_size, - self.seen_max_num_extend_tokens_next_power_of_2, - ) + + # Use PyTorch fallback for large extend_num_tokens to avoid slow Triton compilation + MAX_TRITON_EXTEND_TOKENS = 65536 # 64K + if next_power_of_2(extend_num_tokens) > MAX_TRITON_EXTEND_TOKENS: + alloc_extend_torch_fallback( + prefix_lens_cpu=prefix_lens_cpu, + seq_lens_cpu=seq_lens_cpu, + last_loc=last_loc, + free_pages=self.free_pages, + out_indices=out_indices, + page_size=self.page_size, + debug_mode=self.debug_mode, + ) + else: + self.seen_max_num_extend_tokens_next_power_of_2 = max( + self.seen_max_num_extend_tokens_next_power_of_2, + next_power_of_2(extend_num_tokens), + ) + alloc_extend_kernel[(bs,)]( + prefix_lens, + seq_lens, + last_loc, + self.free_pages, + out_indices, + next_power_of_2(bs), + self.page_size, + self.seen_max_num_extend_tokens_next_power_of_2, + ) if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index f6cfca8b6..5d3cad059 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -11,10 +11,15 @@ import torch from sglang.srt.managers.cache_controller import HiCacheController, 
PrefetchOperation from sglang.srt.mem_cache.base_prefix_cache import MatchResult -from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool +from sglang.srt.mem_cache.memory_pool import ( + MHATokenToKVPool, + MLATokenToKVPool, + NSATokenToKVPool, +) from sglang.srt.mem_cache.memory_pool_host import ( MHATokenToKVPoolHost, MLATokenToKVPoolHost, + NSATokenToKVPoolHost, ) from sglang.srt.mem_cache.radix_cache import ( RadixCache, @@ -54,6 +59,16 @@ class HiRadixCache(RadixCache): server_args.hicache_mem_layout, allocator_type=server_args.hicache_storage_backend, ) + elif isinstance(self.kv_cache, NSATokenToKVPool): + # Check NSA before MLA since NSATokenToKVPool is a subclass of MLATokenToKVPool + self.token_to_kv_pool_host = NSATokenToKVPoolHost( + self.kv_cache, + server_args.hicache_ratio, + server_args.hicache_size, + self.page_size, + server_args.hicache_mem_layout, + allocator_type=server_args.hicache_storage_backend, + ) elif isinstance(self.kv_cache, MLATokenToKVPool): self.token_to_kv_pool_host = MLATokenToKVPoolHost( self.kv_cache, @@ -64,7 +79,7 @@ class HiRadixCache(RadixCache): allocator_type=server_args.hicache_storage_backend, ) else: - raise ValueError(f"HiRadixCache only supports MHA and MLA yet") + raise ValueError(f"HiRadixCache only supports MHA and MLA and NSA yet") self.tp_group = params.tp_cache_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 65d562a27..fe5547d7b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool - else nullcontext() + else nullcontext(), + self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), ): self.index_k_with_scale_buffer = [ torch.zeros( @@ -1700,6 +1701,11 @@ class 
NSATokenToKVPool(MLATokenToKVPool): ) for _ in range(layer_num) ] + self.index_k_with_scale_buffer_ptrs = torch.tensor( + [x.data_ptr() for x in self.index_k_with_scale_buffer], + dtype=torch.uint64, + device=self.device, + ) self._finalize_allocation_log(size) def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: @@ -1775,6 +1781,50 @@ class NSATokenToKVPool(MLATokenToKVPool): ] return data_ptrs, data_lens, item_lens + def get_cpu_copy(self, indices): + # First, save the kv_buffer (inherited from MLATokenToKVPool) + kv_cache_cpu = super().get_cpu_copy(indices) + + # Additionally, save the index_k_with_scale_buffer (page-indexed) + page_indices = indices[:: self.page_size] // self.page_size + torch.cuda.synchronize() + index_k_cpu = [] + chunk_size = self.cpu_offloading_chunk_size + # Convert chunk_size from token-level to page-level + page_chunk_size = max(1, chunk_size // self.page_size) + for layer_id in range(self.layer_num): + index_k_cpu.append([]) + for i in range(0, len(page_indices), page_chunk_size): + chunk_page_indices = page_indices[i : i + page_chunk_size] + idx_cpu = self.index_k_with_scale_buffer[layer_id][ + chunk_page_indices + ].to("cpu", non_blocking=True) + index_k_cpu[-1].append(idx_cpu) + torch.cuda.synchronize() + + return {"kv": kv_cache_cpu, "index_k": index_k_cpu} + + def load_cpu_copy(self, kv_cache_cpu_dict, indices): + # Restore the kv_buffer (inherited from MLATokenToKVPool) + super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices) + + # Restore the index_k_with_scale_buffer (page-indexed) + page_indices = indices[:: self.page_size] // self.page_size + index_k_cpu = kv_cache_cpu_dict["index_k"] + torch.cuda.synchronize() + chunk_size = self.cpu_offloading_chunk_size + page_chunk_size = max(1, chunk_size // self.page_size) + for layer_id in range(self.layer_num): + for i in range(0, len(page_indices), page_chunk_size): + chunk_page_indices = page_indices[i : i + page_chunk_size] + idx_cpu = index_k_cpu[layer_id][i // 
page_chunk_size] + assert idx_cpu.shape[0] == len(chunk_page_indices) + idx_chunk = idx_cpu.to( + self.index_k_with_scale_buffer[0].device, non_blocking=True + ) + self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk + torch.cuda.synchronize() + def get_kv_size_bytes(self): kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 46394158f..d99cc4b3b 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -15,7 +15,12 @@ from sglang.jit_kernel.hicache import ( from sglang.jit_kernel.hicache import ( transfer_hicache_one_layer as jit_transfer_hicache_one_layer, ) -from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool +from sglang.srt.mem_cache.memory_pool import ( + KVCache, + MHATokenToKVPool, + MLATokenToKVPool, + NSATokenToKVPool, +) from sglang.srt.utils import is_cuda, is_npu, is_xpu _is_cuda = is_cuda() @@ -1015,3 +1020,199 @@ class MLATokenToKVPoolHost(HostKVCache): else: raise ValueError(f"Unsupported layout: {self.layout}") return ptr_list, element_size_list + + +class NSATokenToKVPoolHost(MLATokenToKVPoolHost): + """ + Host memory pool for NSA (Native Sparse Attention) KV cache. + + NSA extends MLA with an additional index_k_with_scale_buffer that stores + sparse attention indexing information. This class ensures that buffer is + also backed up and restored during hicache operations. 
+ """ + + device_pool: NSATokenToKVPool + + def __init__( + self, + device_pool: NSATokenToKVPool, + host_to_device_ratio: float, + host_size: int, + page_size: int, + layout: str, + pin_memory: bool = True, + device: str = "cpu", + allocator_type: str = "default", + ): + # Store NSA-specific attributes before calling parent __init__ + self.index_head_dim = device_pool.index_head_dim + self.quant_block_size = device_pool.quant_block_size + self.index_k_with_scale_buffer_dtype = device_pool.index_k_with_scale_buffer_dtype + + super().__init__( + device_pool, + host_to_device_ratio, + host_size, + page_size, + layout, + pin_memory, + device, + allocator_type, + ) + + # Initialize index buffer references and pointers for efficient transfer + self.index_data_refs = [ + self.index_k_with_scale_buffer[i] for i in range(self.layer_num) + ] + self.index_data_ptrs = torch.tensor( + [x.data_ptr() for x in self.index_data_refs], + dtype=torch.uint64, + device=self.device_pool.device, + ) + + def get_size_per_token(self): + # Get base MLA size + base_size = super().get_size_per_token() + + # Add NSA index buffer size per token + # index_k_with_scale_buffer shape per layer: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4)) + # Per token: (index_head_dim + index_head_dim // quant_block_size * 4) * dtype.itemsize * layer_num + index_size_per_token = ( + (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) + * self.index_k_with_scale_buffer_dtype.itemsize + * self.layer_num + ) + + return base_size + index_size_per_token + + def init_kv_buffer(self): + # Initialize base MLA kv_buffer + buffer = super().init_kv_buffer() + + # Initialize NSA index_k_with_scale_buffer on host + # Layout matches device pool: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4)) + index_buffer_second_dim = self.page_size * ( + self.index_head_dim + self.index_head_dim // self.quant_block_size * 4 + ) + 
self.index_stride_size = (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) * self.index_k_with_scale_buffer_dtype.itemsize + + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] + self.index_k_with_scale_buffer = [ + alloc_func( + (self.page_num, index_buffer_second_dim), + dtype=self.index_k_with_scale_buffer_dtype, + device=self.device, + pin_memory=self.pin_memory, + allocator=self.allocator, + ) + for _ in range(self.layer_num) + ] + + return buffer + + def _load_indexer_to_device_per_layer( + self, device_pool, host_indices, device_indices, layer_id, io_backend + ): + """Load index_k_with_scale_buffer from host to device for a specific layer.""" + # Convert token indices to page indices + # host_indices and device_indices are token-level indices + # index_k_with_scale_buffer is page-level with shape (num_pages, page_size * dim) + + if io_backend == "kernel": + # Use page-level copy for index buffer + # Calculate page indices from token indices + page_indices_host = host_indices[:: self.page_size] // self.page_size + page_indices_device = device_indices[:: self.page_size] // self.page_size + + src_buffer = self.index_k_with_scale_buffer[layer_id] + dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer] + + # Copy each page + # for i in range(len(page_indices_host)): + # src_page_idx = page_indices_host[i].item() + # dst_page_idx = page_indices_device[i].item() + # dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) + if self.layout == "layer_first": + transfer_kv_per_layer_mla( + src=src_buffer, + dst=dst_buffer, + src_indices=page_indices_host, + dst_indices=page_indices_device, + item_size=self.index_stride_size * self.page_size, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") + + elif io_backend == "direct": + # Direct I/O copy for index buffer + page_indices_host = host_indices[:: self.page_size] // self.page_size + page_indices_device = 
device_indices[:: self.page_size] // self.page_size + + src_buffer = self.index_k_with_scale_buffer[layer_id] + dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer] + + for i in range(len(page_indices_host)): + src_page_idx = page_indices_host[i].item() + dst_page_idx = page_indices_device[i].item() + dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) + else: + raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}") + + def _backup_indexer_from_device_all_layer( + self, device_pool, host_indices, device_indices, io_backend + ): + """Backup index_k_with_scale_buffer from device to host for all layers.""" + # Convert token indices to page indices + page_indices_host = host_indices[:: self.page_size] // self.page_size + page_indices_device = device_indices[:: self.page_size] // self.page_size + + # if io_backend in ["kernel", "direct"]: + if io_backend == "kernel": + if self.layout == "layer_first": + transfer_kv_all_layer_mla( + src_layers=device_pool.index_k_with_scale_buffer_ptrs, + dst_layers=self.index_data_ptrs, + src_indices=page_indices_device, + dst_indices=page_indices_host, + item_size=self.index_stride_size * self.page_size, + num_layers=self.layer_num, + ) + else: + raise ValueError(f"Unsupported layout: {self.layout}") + elif io_backend == "direct": + for layer_id in range(self.layer_num): + src_buffer = device_pool.index_k_with_scale_buffer[layer_id] + dst_buffer = self.index_k_with_scale_buffer[layer_id] + + for i in range(len(page_indices_device)): + src_page_idx = page_indices_device[i].item() + dst_page_idx = page_indices_host[i].item() + dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) + else: + raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}") + + def load_to_device_per_layer( + self, device_pool, host_indices, device_indices, layer_id, io_backend + ): + """Load KV cache and index buffer from host to device for a specific 
layer.""" + # Load base MLA kv_buffer + super().load_to_device_per_layer( + device_pool, host_indices, device_indices, layer_id, io_backend + ) + # Load NSA index_k_with_scale_buffer + self._load_indexer_to_device_per_layer( + device_pool, host_indices, device_indices, layer_id, io_backend + ) + + def backup_from_device_all_layer( + self, device_pool, host_indices, device_indices, io_backend + ): + """Backup KV cache and index buffer from device to host for all layers.""" + # Backup base MLA kv_buffer + super().backup_from_device_all_layer( + device_pool, host_indices, device_indices, io_backend + ) + # Backup NSA index_k_with_scale_buffer + self._backup_indexer_from_device_all_layer( + device_pool, host_indices, device_indices, io_backend + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1d69c0582..d984c2e12 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): ) # Init routed experts capturer - self.init_routed_experts_capturer() + if not self.is_draft_worker: + self.init_routed_experts_capturer() if self.device == "cuda": self.init_cublas() @@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.expert_distribution_metrics = recorder_outputs.get("metrics") # Copy cached routing experts' buffers back to CPU cache - get_global_experts_capturer().on_forward_end( - forward_batch=forward_batch, - can_run_graph=output.can_run_graph, - cuda_graph_batch=getattr(self.graph_runner, "bs", None), - ) + if not self.is_draft_worker: + # In speculative decoding, num_tokens_per_bs > 1, so we need to pass + # the actual number of tokens per dp rank in cuda graph, not batch size. 
+ cuda_graph_num_tokens = None + if getattr(self.graph_runner, "bs", None): + cuda_graph_num_tokens = ( + self.graph_runner.bs * self.graph_runner.num_tokens_per_bs + ) + get_global_experts_capturer().on_forward_end( + forward_batch=forward_batch, + can_run_graph=output.can_run_graph, + cuda_graph_batch=cuda_graph_num_tokens, + ) if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() @@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin): logger.error(f"IPC weight update failed: {e}") return False, str(e) + def post_process_weights(self, recv_req): + """ + Execute post-processing logic for model weights, such as Marlin quantization format conversion. + """ + from sglang.srt.model_loader.loader import device_loading_context + + target_device = torch.device("cuda", torch.cuda.current_device()) + + if recv_req.restore_weights_before_load: + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + + # Check if the module supports restoring weights + if quant_method is not None and hasattr( + quant_method, "restore_weights_before_loading" + ): + + with device_loading_context(module, target_device): + quant_method.restore_weights_before_loading(module) + + if recv_req.post_process_quantization: + # Iterate through all modules to apply specific post-loading processing + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + + # Check if the module supports quantization post-processing + if quant_method is not None and hasattr( + quant_method, "process_weights_after_loading" + ): + + # Apply the post-processing (e.g., repacking weights for Marlin kernel) + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + return True, "Success" def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git 
a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ed8cc7ada..b8f1026dd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -159,6 +159,7 @@ from sglang.srt.utils import ( make_layers, use_intel_amx_backend, ) +from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend _is_hip = is_hip() _is_cuda = is_cuda() @@ -434,6 +435,8 @@ def handle_attention_nsa(attn, forward_batch): backend = forward_batch.attn_backend if isinstance(backend, TboAttnBackend): # if enable tbo, get primary backend backend = backend.primary + if isinstance(backend, HybridAttnBackend): + backend = backend._select_backend(forward_batch.forward_mode) if hasattr(backend, "use_mha") and backend.use_mha: return AttnForwardMethod.MHA_ONE_SHOT return AttnForwardMethod.MLA @@ -2704,7 +2707,11 @@ class DeepseekV2AttentionMLA(nn.Module): ): k = k_nope.new_empty(*k_shape) concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) - elif _is_cuda: + elif _is_cuda and all( + # (i.bit_count() == 1) == (is_power_of_two(i)) + i.bit_count() == 1 + for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1]) + ): # fa3 mha support fp8 inputs if ( self.current_attention_backend == "fa3" diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index a7dbadec6..c83a41338 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): self.act_fn = SiluAndMul() def forward(self, x): - if get_global_server_args().rl_on_policy_target is not None: - x = x.bfloat16() - gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): quant_config=quant_config, enable_tp=not is_dp_attention_enabled(), prefix=add_prefix("embed_tokens", prefix), - params_dtype=( - torch.float32 - if get_global_server_args().rl_on_policy_target is not None - else None - ), ) else: 
self.embed_tokens = PPMissingLayer() @@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): if self.pp_group.is_last_rank: norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, - override_orig_dtype=torch.float32, - fp32_residual=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 3ad9f6736..0b9c7f499 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): prefix=add_prefix("layers", prefix), ) if self.pp_group.is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) else: self.norm = PPMissingLayer(return_tuple=True) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 9220831f6..2b8303b54 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} @@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): norm_kwargs = ( dict( - weight_dtype=torch.float32, cast_x_before_out_mul=True, - override_orig_dtype=torch.float32, - fp32_residual=True, + fp32_residual=False, ) if get_global_server_args().rl_on_policy_target is not None else {} @@ -276,14 +274,14 @@ class Qwen3DecoderLayer(nn.Module): hidden_states: torch.Tensor, forward_batch: ForwardBatch, residual: Optional[torch.Tensor], - **kwargs, + post_residual_addition: Optional[torch.Tensor] = None, ) -> 
Tuple[torch.Tensor, torch.Tensor]: # Self Attention hidden_states, residual = self.layer_communicator.prepare_attn( hidden_states, residual, forward_batch, - **kwargs, + post_residual_addition=post_residual_addition, ) if hidden_states.shape[0] != 0: hidden_states = self.self_attn( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index e11678a9e..e277d46f2 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -22,6 +22,7 @@ import math from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar import torch +import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig @@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( ) from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): use_grouped_topk=False, layer_id=layer_id, ) + self.top_k = config.num_experts_per_tok self.experts = get_moe_impl_class(quant_config)( num_experts=config.num_experts @@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_output = self.topk(hidden_states, router_logits) + + if get_global_server_args().rl_on_policy_target is not None: + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + topk_output = 
StandardTopKOutput( + topk_weights=routing_weights, + topk_ids=selected_experts, + router_logits=router_logits, + ) + else: + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if ( self.tp_size > 1 @@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): ) self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True - ) + ) and (get_global_server_args().rl_on_policy_target is None) self.compatible_with_fused_qk_norm_rope = ( not isinstance(self.rotary_emb, MRotaryEmbedding) ) and self.head_dim in (64, 128, 256) self.use_fused_qk_norm_rope = ( get_global_server_args().enable_fused_qk_norm_rope and self.compatible_with_fused_qk_norm_rope + and (get_global_server_args().rl_on_policy_target is None) ) self._used_fused_qk_norm_rope_last_call = False @@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): prefix=add_prefix("attn", prefix), ) - self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) self.alt_stream = alt_stream def op_prepare(self, state): @@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module): quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + norm_kwargs = ( + dict( + cast_x_before_out_mul=True, + fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} + ) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps + 
config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ) self.layer_communicator = LayerCommunicator( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 079f45843..218e32362 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): return cos_combined, sin_combined def fast_pos_embed_interpolate(self, grid_thw): - patch_pos_embeds_permute = [] - m_size = self.spatial_merge_size + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + num_grid_per_side = int(self.num_position_embeddings**0.5) + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * num_grid_per_side + base_h_ceil = h_idxs_ceil * num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] - embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device) - embeds = ( - self.pos_embed(embeds) - .permute(1, 0) - .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + 
weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device ) - for t, h, w in grid_thw: - pos_embed = torch.nn.functional.interpolate( - embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners - ) - pos_embed = pos_embed.reshape( - -1, - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + patch_pos_embeds_permute = [] + merge_size = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) ) - pos_embed = pos_embed.permute(1, 3, 2, 4, 0) - pos_embed = pos_embed.flatten(0, 3).repeat(t, 1) patch_pos_embeds_permute.append(pos_embed) return torch.cat(patch_pos_embeds_permute) @@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model): hidden_states + residual if residual is not None else hidden_states ) + deepstack_embeds = None + if input_deepstack_embeds is not None: + prev_layer_idx = layer_idx - 1 + if prev_layer_idx in self.deepstack_embed_to_decoder_layer: + sep = self.hidden_size * prev_layer_idx + deepstack_embeds = input_deepstack_embeds[ + :, sep : sep + self.hidden_size + ] + # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. 
# See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack # The order matters because addition with different tensors is not associative in practice. - # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. - deepstack_embeds = self.get_deepstack_embeds( - layer_idx - 1, input_deepstack_embeds - ) hidden_states, residual = layer( positions, hidden_states, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a2b26e0e0..72db29801 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -527,6 +527,7 @@ class ServerArgs: cuda_graph_max_bs: Optional[int] = None cuda_graph_bs: Optional[List[int]] = None disable_cuda_graph: bool = False + disable_draft_cuda_graph: bool = False disable_cuda_graph_padding: bool = False enable_profile_cuda_graph: bool = False enable_cudagraph_gc: bool = False @@ -3980,6 +3981,11 @@ class ServerArgs: action="store_true", help="Disable cuda graph.", ) + parser.add_argument( + "--disable-draft-cuda-graph", + action="store_true", + help="Disable cuda graph for draft model in speculative decoding.", + ) parser.add_argument( "--disable-cuda-graph-padding", action="store_true", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 5fe45086c..c95fbd0f6 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() self.positions.zero_() - + self.topk_p.zero_() + self.topk_index.zero_() + self.hidden_states.zero_() + self.req_pool_indices.zero_() num_tokens = bs * 
self.num_tokens_per_bs # Common inputs @@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner: forward_batch.out_cache_loc ) self.positions[:raw_num_token].copy_(forward_batch.positions) - self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) - self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) + self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1)) self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 1bf3816e9..b5b41dba4 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] + if self.accept_length is not None: + self.accept_length = self.accept_length[: len(new_indices)] + if self.accept_length_cpu is not None: + self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] else: # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` self.topk_p = self.topk_p[new_indices] @@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + if self.accept_length is not None and spec_info.accept_length is not None: + self.accept_length = torch.cat( + [self.accept_length, spec_info.accept_length] + ) + self.accept_length_cpu = self.accept_length.tolist() + elif 
self.accept_length is not None: + zeros = torch.zeros( + [spec_info.verified_id.shape[0]], + dtype=self.accept_length.dtype, + device=self.accept_length.device, + ) + self.accept_length = torch.cat([self.accept_length, zeros]) + self.accept_length_cpu = self.accept_length.tolist() + elif spec_info.accept_length is not None: + zeros = torch.zeros( + [self.verified_id.shape[0] - spec_info.verified_id.shape[0]], + dtype=spec_info.accept_length.dtype, + device=spec_info.accept_length.device, + ) + self.accept_length = torch.cat([zeros, spec_info.accept_length]) + self.accept_length_cpu = self.accept_length.tolist() @dataclass diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index a702df4f8..61d9ae366 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker): self.cuda_graph_runner = None self.cuda_graph_runner_for_draft_extend = None - if self.server_args.disable_cuda_graph: + if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph: return Device2DraftCudaGraphRunner = { diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 8560246c6..13db860dc 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler): "sglang.srt.model_executor.model_runner.", "sglang.srt.layers.", "sglang.srt.utils.", + # --- slime --- + "slime.", } DENY_CLASSES = {