JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index aa10cb08d..d41c31a09 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -268,6 +268,12 @@ class ModelConfig:
):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+ if (
+ is_draft_model
+ and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM"
+ ):
+ self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 51af67636..3ec1778ed 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
+import os
import time
from collections import deque
from dataclasses import dataclass
@@ -315,6 +316,16 @@ class DecodePreallocQueue:
)
return kv_manager
+ def release_memory_occupation(self):
+ self.queue.clear()
+ self.retracted_queue.clear()
+ if hasattr(self.kv_manager, "deregister_buffer_to_engine"):
+ self.kv_manager.deregister_buffer_to_engine()
+
+ def resume_memory_occupation(self):
+ if hasattr(self.kv_manager, "register_buffer_to_engine"):
+ self.kv_manager.register_buffer_to_engine()
+
def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
if self._check_if_req_exceed_kv_capacity(req):
@@ -419,12 +430,37 @@ class DecodePreallocQueue:
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
+ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed.
+ bootstrap_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue
if poll == KVPoll.Bootstrapping:
- pass
+ # Check for bootstrap timeout
+ entry_time = getattr(
+ decode_req.req.time_stats,
+ "decode_prealloc_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > bootstrap_timeout:
+ error_message = (
+ f"Decode bootstrap timed out after {now - entry_time:.1f}s "
+ f"for request rank={self.tp_rank} "
+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ prepare_abort(
+ decode_req.req,
+ error_message,
+ status_code=HTTPStatus.GATEWAY_TIMEOUT,
+ )
+ if self.scheduler.enable_metrics:
+ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
elif poll == KVPoll.WaitingForInput:
decode_req.waiting_for_input = True
elif poll == KVPoll.Failed:
@@ -776,6 +812,13 @@ class DecodeTransferQueue:
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
+ # Transfer timeout: if a request has been in the transfer queue for too long
+ # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed.
+ transfer_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
transferred_reqs = []
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -811,7 +854,33 @@ class DecodeTransferQueue:
KVPoll.WaitingForInput,
KVPoll.Transferring,
]:
- pass
+ # Check for transfer timeout
+ entry_time = getattr(
+ decode_req.req.time_stats,
+ "decode_transfer_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > transfer_timeout:
+ error_message = (
+ f"Decode transfer timed out after {now - entry_time:.1f}s "
+ f"(state={poll}) for request rank={self.tp_rank} "
+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ prepare_abort(
+ decode_req.req,
+ error_message,
+ status_code=HTTPStatus.GATEWAY_TIMEOUT,
+ )
+ self.scheduler.stream_output(
+ [decode_req.req], decode_req.req.return_logprob
+ )
+ release_kv_cache(
+ decode_req.req, self.tree_cache, is_insert=False
+ )
+ indices_to_remove.add(i)
+ if self.scheduler.enable_metrics:
+ self.scheduler.metrics_collector.increment_transfer_failed_reqs()
else:
raise ValueError(f"Unexpected poll case: {poll}")
@@ -827,6 +896,14 @@ class DecodeTransferQueue:
return transferred_reqs
+ def release_memory_occupation(self):
+ """Clean up all in-flight transfers before releasing GPU memory."""
+ self.queue.clear()
+
+ def resume_memory_occupation(self):
+ """Resume after GPU memory re-allocation. Queue was already cleared on release."""
+ pass
+
class SchedulerDisaggregationDecodeMixin:
@@ -1004,7 +1081,15 @@ class SchedulerDisaggregationDecodeMixin:
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
self.waiting_queue.extend(resumed_reqs)
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
- # if there are still retracted requests, we do not allocate new requests
+ # Still have retracted requests that couldn't resume (not enough memory).
+ # Don't accept new requests (pop_preallocated) — they would consume memory
+ # that retracted requests need.
+ # But DO drain completed transfers: their KV is already committed, and
+ # moving them to waiting_queue frees the reserved-decode-token budget
+ # in _allocatable_tokens(), which may unblock resume on the next iteration.
+ # Without this, completed transfers hold memory indefinitely → deadlock.
+ alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred()
+ self.waiting_queue.extend(alloc_reqs)
return
if not hasattr(self, "polling_count"):
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index 32e8c0b69..dc93c5c5f 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -253,6 +253,19 @@ class MooncakeKVManager(CommonKVManager):
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
+ def deregister_buffer_to_engine(self):
+ # Batch deregister KV data buffers
+ if self.kv_args.kv_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.kv_data_ptrs)
+
+ # Batch deregister auxiliary data buffers
+ if self.kv_args.aux_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.aux_data_ptrs)
+
+ # Batch deregister state/extra pool data buffers
+ if self.kv_args.state_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.state_data_ptrs)
+
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index a6eed743a..191b0977f 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from __future__ import annotations
import logging
+import os
import time
from collections import deque
from http import HTTPStatus
@@ -250,6 +251,12 @@ class PrefillBootstrapQueue:
[req.disagg_kv_sender for req in self.queue], self.gloo_group
)
+ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed.
+ bootstrap_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None:
# if req not in reqs_info_to_check, skip
@@ -257,6 +264,27 @@ class PrefillBootstrapQueue:
continue
if poll == KVPoll.Bootstrapping:
+ # Check for bootstrap timeout
+ entry_time = getattr(
+ req.time_stats,
+ "prefill_bootstrap_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > bootstrap_timeout:
+ error_message = (
+ f"Prefill bootstrap timed out after {now - entry_time:.1f}s "
+ f"for request rank={self.tp_rank} "
+ f"{req.rid=} {req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ prepare_abort(
+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT
+ )
+ self.scheduler.stream_output([req], req.return_logprob)
+ indices_to_remove.add(i)
+ failed_reqs.append(req)
+ if self.scheduler.enable_metrics:
+ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
continue
elif poll == KVPoll.Failed:
error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
@@ -306,6 +334,15 @@ class PrefillBootstrapQueue:
else:
return bootstrapped_reqs, failed_reqs
+ def release_memory_occupation(self):
+ self.queue.clear()
+ if hasattr(self.kv_manager, "deregister_buffer_to_engine"):
+ self.kv_manager.deregister_buffer_to_engine()
+
+ def resume_memory_occupation(self):
+ if hasattr(self.kv_manager, "register_buffer_to_engine"):
+ self.kv_manager.register_buffer_to_engine()
+
class SchedulerDisaggregationPrefillMixin:
"""
@@ -535,6 +572,13 @@ class SchedulerDisaggregationPrefillMixin:
self.attn_tp_cpu_group,
)
+ # Transfer timeout: if a request has been in the inflight queue for too long
+ # (e.g., stuck in WaitingForInput/Transferring), treat it as failed.
+ transfer_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
undone_reqs: List[Req] = []
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
@@ -547,7 +591,30 @@ class SchedulerDisaggregationPrefillMixin:
assert poll == KVPoll.Success or poll == KVPoll.Failed
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
- undone_reqs.append(req)
+ # Check for transfer timeout
+ entry_time = getattr(
+ req.time_stats,
+ "prefill_transfer_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > transfer_timeout:
+ error_message = (
+ f"Prefill transfer timed out after {now - entry_time:.1f}s "
+ f"(state={poll}) for request rank={self.tp_rank} "
+ f"{req.rid=} {req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ release_kv_cache(req, self.tree_cache) # unlock the tree
+ prepare_abort(
+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT
+ )
+ if hasattr(req.disagg_kv_sender, "clear"):
+ req.disagg_kv_sender.clear()
+ done_reqs.append(req)
+ if self.enable_metrics:
+ self.metrics_collector.increment_transfer_failed_reqs()
+ else:
+ undone_reqs.append(req)
elif poll == KVPoll.Success: # transfer done
release_kv_cache(req, self.tree_cache) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 0478526ef..cfb1aa669 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
- return get_tp_group().rank_in_group
+ try:
+ return get_tp_group().rank_in_group
+ except Exception:
+ return 0
def get_pipeline_model_parallel_world_size():
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 6f69fd19b..da20ac2ed 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
MultimodalDataInputFormat,
+ PostProcessWeightsReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
@@ -593,6 +594,24 @@ class Engine(EngineBase):
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
+ def post_process_weights(
+ self,
+ restore_weights_before_load: bool = False,
+ post_process_quantization: bool = False,
+ ):
+ """
+ Optional post-processing for updated weights (e.g., Marlin conversion).
+ Should be called after weight update is finished.
+ """
+ obj = PostProcessWeightsReqInput(
+ restore_weights_before_load=restore_weights_before_load,
+ post_process_quantization=post_process_quantization,
+ )
+
+ return self.loop.run_until_complete(
+ self.tokenizer_manager.post_process_weights(obj, None)
+ )
+
def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 88705cc35..c8dc052f1 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+@app.post("/post_process_weights")
+async def post_process_weights(req: PostProcessWeightsReqInput, request: Request):
+ """
+ Optional post-processing for updated weights (e.g., Marlin conversion).
+ This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`.
+ """
+ success, message = await _global_state.tokenizer_manager.post_process_weights(
+ req, request
+ )
+
+ content = {"success": success, "message": message}
+ return ORJSONResponse(
+ content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+ )
+
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
index d6c499df0..565004260 100644
--- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
+++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
@@ -613,7 +613,6 @@ def _get_k_and_s_triton(
page_indices,
k_out,
s_out,
- seq_len,
page_size,
buf_numel_per_page,
index_head_dim,
@@ -630,7 +629,6 @@ def _get_k_and_s_triton_kernel(
page_indices_ptr,
k_out_ptr,
s_out_ptr,
- seq_len: tl.constexpr,
page_size: tl.constexpr,
buf_numel_per_page: tl.constexpr,
index_head_dim: tl.constexpr,
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index c9e82e4b1..f2584546a 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import os
import torch
from einops import rearrange
@@ -178,7 +179,7 @@ class Indexer(MultiPlatformOp):
max_position=max_position_embeddings,
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
- is_neox_style=True,
+ is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False,
device=get_global_server_args().device,
)
self.block_size = block_size
@@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp):
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x.float())
+ if weights.shape[1] < 32:
+ assert 32 % weights.shape[1] == 0
+ weights = weights.repeat_interleave(32 // weights.shape[1], dim=1)
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
@@ -837,6 +841,9 @@ class Indexer(MultiPlatformOp):
query, key = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
)
+ if query.shape[1] < 32:
+ assert 32 % query.shape[1] == 0
+ query = query.repeat_interleave(32//query.shape[1], dim=1)
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 15df851eb..1636ed706 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -371,10 +371,13 @@ class LayerCommunicator:
residual: torch.Tensor,
forward_batch: ForwardBatch,
captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
):
hidden_states, residual = self.prepare_attn(
- hidden_states, residual, forward_batch, **kwargs
+ hidden_states,
+ residual,
+ forward_batch,
+ post_residual_addition=post_residual_addition,
)
if captured_last_layer_outputs is not None:
gathered_last_layer_output = self._communicate_simple_fn(
@@ -394,7 +397,7 @@ class LayerCommunicator:
residual: torch.Tensor,
forward_batch: ForwardBatch,
quant_format: str = "",
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
):
if get_attn_tp_context().input_scattered:
hidden_states, residual = self._tp_reduce_scatter(
@@ -444,7 +447,7 @@ class LayerCommunicator:
)
else:
- hidden_states = self.input_layernorm(hidden_states, **kwargs)
+ hidden_states = self.input_layernorm(hidden_states)
else:
if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
@@ -478,7 +481,7 @@ class LayerCommunicator:
hidden_states, residual = self.input_layernorm(
hidden_states,
residual,
- **kwargs,
+ post_residual_addition,
)
hidden_states = self._communicate_simple_fn(
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 7bef9d2ab..5926ff7f5 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
cast_x_before_out_mul: bool = False,
- fp32_residual: bool = False,
- weight_dtype: Optional = None,
- override_orig_dtype: Optional = None,
+ fp32_residual: bool = True,
) -> None:
super().__init__()
self.cast_x_before_out_mul = cast_x_before_out_mul
self.fp32_residual = fp32_residual
- self.override_orig_dtype = override_orig_dtype
- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
+ self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
self.variance_size_override = (
@@ -104,16 +101,16 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
- return self.forward_native(x, residual, **kwargs)
+ return self.forward_native(x, residual, post_residual_addition)
if is_batch_invariant_mode_enabled():
if (
residual is not None
or get_global_server_args().rl_on_policy_target == "fsdp"
):
- return self.forward_native(x, residual, **kwargs)
+ return self.forward_native(x, residual, post_residual_addition)
return rms_norm_batch_invariant(
x,
self.weight.data,
@@ -124,7 +121,6 @@ class RMSNorm(MultiPlatformOp):
# but right now we can only have hidden_states+(residual+post_residual_addition).
# (hidden_states+residual)+post_residual_addition != hidden_states+(residual+post_residual_addition),
# we probably need to add another parameter to fused_add_rmsnorm
- post_residual_addition = kwargs.get("post_residual_addition")
if post_residual_addition is not None:
residual = residual + post_residual_addition
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
@@ -136,9 +132,11 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
out, _, residual_out = torch_npu.npu_add_rms_norm(
residual, x, self.weight.data, self.variance_epsilon
)
@@ -149,9 +147,11 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
fused_add_rms_norm(
@@ -169,12 +169,14 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
# NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
out = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
@@ -189,27 +191,23 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
- orig_dtype = self.override_orig_dtype or x.dtype
- post_residual_addition = kwargs.get("post_residual_addition")
+ orig_dtype = x.dtype
+
+ if residual is not None and not self.fp32_residual:
+ x = x + residual
+ if post_residual_addition is not None:
+ x = x + post_residual_addition
+ residual = x.clone()
x = x.to(torch.float32)
- if residual is not None:
- x = (
- x
- + residual.to(torch.float32)
- + (
- post_residual_addition.to(torch.float32)
- if post_residual_addition is not None
- else 0.0
- )
- )
- if self.fp32_residual:
- residual = x.clone()
- else:
- residual = x.to(orig_dtype)
+ if residual is not None and self.fp32_residual:
+ x = x + residual.to(torch.float32)
+ if post_residual_addition is not None:
+ x = x + post_residual_addition.to(torch.float32)
+ residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
@@ -246,10 +244,12 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if _is_cpu_amx_available:
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
x, residual, self.weight.data, self.variance_epsilon
)
@@ -258,17 +258,19 @@ class RMSNorm(MultiPlatformOp):
x, self.weight.data, self.variance_epsilon
)
else:
- return self.forward_native(x, residual, **kwargs)
+ return self.forward_native(x, residual, post_residual_addition)
def forward_xpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
- return self.forward_native(x, residual, **kwargs)
+ return self.forward_native(x, residual, post_residual_addition)
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
@@ -278,6 +280,7 @@ class RMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Forward method with allreduce fusion, prioritizing flashinfer fused operations
@@ -289,6 +292,8 @@ class RMSNorm(MultiPlatformOp):
)
if get_tensor_model_parallel_world_size() > 1:
+ if post_residual_addition is not None:
+ x = x + post_residual_addition
fused_result = flashinfer_allreduce_residual_rmsnorm(
input_tensor=x,
residual=residual,
@@ -298,7 +303,7 @@ class RMSNorm(MultiPlatformOp):
if fused_result[0] is not None:
return fused_result
- return self.forward(x, residual)
+ return self.forward(x, residual, post_residual_addition)
class LayerNorm(MultiPlatformOp):
@@ -323,7 +328,6 @@ class LayerNorm(MultiPlatformOp):
def forward_cuda(
self,
x: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
if (
_flashinfer_layernorm_available
@@ -332,12 +336,11 @@ class LayerNorm(MultiPlatformOp):
):
return layernorm(x, self.weight, self.bias, self.variance_epsilon)
else:
- return self.forward_native(x, **kwargs)
+ return self.forward_native(x)
def forward_native(
self,
x: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
weight = self.weight if self.elementwise_affine else None
bias = self.bias if self.use_bias else None
@@ -354,28 +357,25 @@ class LayerNorm(MultiPlatformOp):
def forward_hip(
self,
x: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
- return self.forward_native(x, **kwargs)
+ return self.forward_native(x)
def forward_npu(
self,
x: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
- return self.forward_native(x, **kwargs)
+ return self.forward_native(x)
def forward_cpu(
self,
x: torch.Tensor,
- **kwargs,
) -> torch.Tensor:
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.layernorm_cpu(
x, self.weight.data, self.variance_epsilon
)
else:
- return self.forward_native(x, **kwargs)
+ return self.forward_native(x)
class GemmaRMSNorm(MultiPlatformOp):
@@ -396,9 +396,11 @@ class GemmaRMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
gemma_fused_add_rmsnorm(
x, residual, self.weight.data, self.variance_epsilon
)
@@ -410,11 +412,13 @@ class GemmaRMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
if residual is not None:
x = x + residual
+ if post_residual_addition is not None:
+ x = x + post_residual_addition
residual = x
x = x.float()
@@ -428,18 +432,20 @@ class GemmaRMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- return self._forward_impl(x, residual, **kwargs)
+ return self._forward_impl(x, residual, post_residual_addition)
def forward_cpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if _is_cpu_amx_available:
if residual is not None:
+ if post_residual_addition is not None:
+ residual = residual + post_residual_addition
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu(
x, residual, self.weight.data, self.variance_epsilon
)
@@ -447,16 +453,18 @@ class GemmaRMSNorm(MultiPlatformOp):
return torch.ops.sgl_kernel.gemma_rmsnorm_cpu(
x, self.weight.data, self.variance_epsilon
)
- return self.forward_native(x, residual, **kwargs)
+ return self.forward_native(x, residual, post_residual_addition)
def forward_npu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x = x + residual
+ if post_residual_addition is not None:
+ x = x + post_residual_addition
residual = x
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
@@ -466,9 +474,9 @@ class GemmaRMSNorm(MultiPlatformOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- return self._forward_impl(x, residual, **kwargs)
+ return self._forward_impl(x, residual, post_residual_addition)
class Gemma3RMSNorm(MultiPlatformOp):
@@ -481,22 +489,22 @@ class Gemma3RMSNorm(MultiPlatformOp):
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
- def forward_native(self, x, **kwargs):
+ def forward_native(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
- def forward_cpu(self, x, **kwargs):
+ def forward_cpu(self, x):
if _is_cpu_amx_available and x.stride(-1) == 1:
return torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, self.weight, self.eps)
- return self.forward_native(x, **kwargs)
+ return self.forward_native(x)
- def forward_cuda(self, x, **kwargs):
- return self.forward_native(x, **kwargs)
+ def forward_cuda(self, x):
+ return self.forward_native(x)
- def forward_npu(self, x, **kwargs):
+ def forward_npu(self, x):
output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
return output
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index fa7431048..cd33ea735 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module):
None, # bias
True, # is_vnni
)
- elif get_global_server_args().rl_on_policy_target is not None:
- # Due to tie-weight, we may not be able to change lm_head's weight dtype
- logits = torch.matmul(
- hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
- )
else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index a1885fade..14d692365 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
import triton.language as tl
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -573,7 +574,10 @@ def fused_experts_impl(
).squeeze(dim=1)
else:
# According to micro benchmark results, torch.compile can get better performance for small token.
- if tokens_in_chunk <= 32:
+ if (
+ not get_global_server_args().enable_deterministic_inference
+ and tokens_in_chunk <= 32
+ ):
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 839463518..7948779aa 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
- )
+ ) and "zero" not in weight_name
else loaded_weight
)
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index 00bd68755..5a3ca8a67 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -1,5 +1,6 @@
import logging
from abc import ABC
+from contextlib import contextmanager
from typing import Optional
import numpy as np
@@ -8,13 +9,18 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import (
+ attn_tp_all_gather_into_tensor,
get_attention_dp_rank,
+ get_attention_tp_size,
get_dp_local_info,
is_dp_attention_enabled,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
+from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+)
logger = logging.getLogger(__name__)
@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
device=device,
)
+ if get_moe_a2a_backend().is_deepep():
+ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1
+ self.gather_buffer = torch.empty(
+ (
+ self.device_cache.buffer.shape[0] * attn_tp_size,
+ self.device_cache.buffer.shape[2],
+ ),
+ dtype=torch.int32,
+ device=device,
+ )
+
def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
- if is_dp_attention_enabled():
+ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer
+ # contains data from all DP ranks. We should not slice by DP rank in this case.
+ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# handle with cuda graph padding
if can_run_graph:
@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
].cpu()
def capture(self, layer_id: int, topk_ids: torch.Tensor):
+ if get_moe_a2a_backend().is_deepep():
+ local_topk_ids = topk_ids
+ topk_ids = self.gather_buffer[
+ : local_topk_ids.size(0) * get_attention_tp_size()
+ ]
+ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids)
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
def get_routed_experts(
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index b4bdc41b3..3b895ff6a 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
)
is_static = not weight_quant.dynamic
- return is_channel_group and input_quant_none and is_symmetric and is_static
+ return is_channel_group and input_quant_none and is_static
def _get_scheme_from_parts(
self, weight_quant: BaseModel, input_quant: BaseModel
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c5e5a11fc..c46526ecc 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
-from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
+from sglang.srt.layers.quantization.marlin_utils import (
+ marlin_moe_permute_scales,
+ moe_awq_to_marlin_zero_points
+)
from sglang.srt.layers.quantization.utils import (
all_close_1d,
per_tensor_dequantize,
@@ -865,7 +868,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
- assert config.symmetric, "Only symmetric quantization is supported for MoE"
+ self.sym = config.symmetric
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
@@ -920,7 +923,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
- load_full_w2 = self.actorder and self.group_size != -1
+ load_full_w2 = (self.actorder != 'static') and self.group_size != -1
if load_full_w2:
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
@@ -968,6 +971,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+ # add zero param
+ if not self.sym:
+ w13_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w13,
+ 2 * intermediate_size_per_partition // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight_zero_point", w13_qzeros)
+ set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+ w2_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w2,
+ hidden_size // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_weight_zero_point", w2_qzeros)
+ set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
@@ -1016,13 +1045,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.a2_scale = None
layer.marlin_state = GPTQMarlinState.REPACK
+ if not hasattr(layer, "_original_shapes"):
+ layer._original_shapes = {}
+
+ # Force record: these are the target GPTQ shapes for rollback.
+ layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
+ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+ if not self.sym:
+ layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape
+
+ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
+ layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
+ if not self.sym:
+ layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape)
+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ # Skip if the layer is already converted to Marlin format to prevent double-packing.
+ if getattr(layer, "is_marlin_converted", False):
+ return
+
+ if not hasattr(layer, "_original_shapes"):
+ layer._original_shapes = {}
def replace_tensor(name, new_t):
+ target_attr = getattr(layer, name)
+
+ # Only save if the key doesn't exist to prevent overwriting with Marlin shapes.
+ if name not in layer._original_shapes:
+ # This is a safety check; `create_weights` usually handles this already.
+ layer._original_shapes[name] = tuple(target_attr.shape)
+
# It is important to use resize_() here since it ensures
# the same buffer is reused
- getattr(layer, name).resize_(new_t.shape)
- getattr(layer, name).copy_(new_t)
+ target_attr.resize_(new_t.shape)
+ target_attr.copy_(new_t)
del new_t
num_experts = layer.w13_weight_g_idx.shape[0]
@@ -1078,7 +1134,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[2],
self.num_bits,
)
- replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+ replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
@@ -1086,7 +1142,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed.shape[2],
self.num_bits,
)
- replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
+ replace_tensor("w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
@@ -1094,7 +1150,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale.shape[2],
self.group_size,
)
- replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
+ replace_tensor("w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
@@ -1103,7 +1159,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.shape[2],
self.group_size,
)
- replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+ replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+ # Repack zero
+ if not self.sym:
+ marlin_w13_zp = moe_awq_to_marlin_zero_points(
+ layer.w13_weight_zero_point,
+ size_k=layer.w13_weight_zero_point.shape[1],
+ size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w13_weight_zero_point", marlin_w13_zp)
+
+ marlin_w2_zp = moe_awq_to_marlin_zero_points(
+ layer.w2_weight_zero_point,
+ size_k=layer.w2_weight_zero_point.shape[1],
+ size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w2_weight_zero_point", marlin_w2_zp)
+
+ layer.is_marlin_converted = True
+
+ def restore_weights_before_loading(self, layer: torch.nn.Module):
+ """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights."""
+ if not hasattr(layer, "_original_shapes"):
+ return
+
+ for name, orig_shape in layer._original_shapes.items():
+ param = getattr(layer, name, None)
+
+ if param is not None and param.shape != orig_shape:
+ param.resize_(orig_shape)
+
+ layer.is_marlin_converted = False
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -1154,6 +1243,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
+ w1_zeros=layer.w13_weight_zero_point if not self.sym else None,
+ w2_zeros=layer.w2_weight_zero_point if not self.sym else None,
num_bits=self.num_bits,
is_k_full=self.is_k_full,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 480579e01..dd8ca7d4f 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp):
if get_global_server_args().rl_on_policy_target is not None:
self._forward_method = self.forward_native
- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
- self._apply_rotary_emb_wrapped
- )
+
self.position_cos, self.position_sin = None, None
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert (
+ fused_set_kv_buffer_arg is None
+ ), "fused_set_kv_buffer_arg is not supported for npu implementation"
# TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096
assert (
fused_set_kv_buffer_arg is None
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 55bef5652..35ad68b1c 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -108,16 +108,11 @@ class Sampler(nn.Module):
if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
- if get_global_server_args().rl_on_policy_target is not None:
- logits_div_temperature = (
- logits.bfloat16().div(sampling_info.temperatures).bfloat16()
- )
- logprobs_via_logsoftmax_kernel = torch.log_softmax(
- logits_div_temperature, dim=-1
- )
-
# Post process logits
logits.div_(sampling_info.temperatures)
+ if get_global_server_args().rl_on_policy_target is not None:
+ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1)
+
# For ascend backend, softmax is not needed before sampling
if not get_global_server_args().sampling_backend == "ascend" or (
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 2ecd8542f..2a2e011ea 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq):
success: bool
message: str
+@dataclass
+class PostProcessWeightsReqInput(BaseReq):
+ # Whether to restore weights before loading new weights
+ restore_weights_before_load: bool = False
+ # Whether to enable quantization post-processing
+ post_process_quantization: bool = False
+
+
+@dataclass
+class PostProcessWeightsReqOutput(BaseReq):
+ success: bool
+ message: str
+
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
@@ -1667,6 +1680,10 @@ class GetLoadReqOutput(BaseReq):
num_waiting_reqs: int
num_tokens: int
ts_tic: float
+ # Per-queue breakdown: list of {name, num_reqs, num_tokens, reqs: [{rid, seqlen, input_len, output_len}]}
+ queue_details: Optional[List[Dict[str, Any]]] = None
+ # Running batch info
+ running_details: Optional[Dict[str, Any]] = None
@dataclass
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d423e61d7..d1f54a832 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1779,7 +1779,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
selected_indices=sorted_indices, buf_multiplier=buf_multiplier
)
):
- if len(sorted_indices) == 1:
+ # We should allow all requests to be retracted in decode disaggregation mode
+ # because there call be prealloc prefill requests.
+ num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1
+ if len(sorted_indices) == num_minimum_reqs:
# Always keep at least one request
break
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 92d286897..43bfab691 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -1060,6 +1061,7 @@ class Scheduler(
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
+ (PostProcessWeightsReqInput, self.post_process_weights),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py
index d44ff6027..3fad54598 100644
--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -553,12 +553,48 @@ class SchedulerMetricsMixin:
num_tokens += sum(req.seqlen for queue in waiting_queues for req in queue)
num_waiting_reqs = sum(len(queue) for queue in waiting_queues)
+ # Collect per-queue details
+ queue_names = ["waiting_queue"]
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ queue_names.append("bootstrap_queue")
+ elif self.disaggregation_mode == DisaggregationMode.DECODE:
+ queue_names.append("prealloc_queue")
+ queue_names.append("transfer_queue")
+ queue_names.append("retracted_queue")
+
+ queue_details = []
+ for name, queue in zip(queue_names, waiting_queues):
+ reqs_info = []
+ for req in queue:
+ reqs_info.append({
+ "seqlen": req.seqlen,
+ })
+ queue_details.append({
+ "name": name,
+ "num_reqs": len(queue),
+ "num_tokens": sum(r["seqlen"] for r in reqs_info),
+ "reqs": reqs_info,
+ })
+
+ # Collect running batch details
+ running_reqs_info = []
+ for req in self.running_batch.reqs:
+ running_reqs_info.append({
+ "seqlen": req.seqlen,
+ })
+ running_details = {
+ "num_reqs": len(self.running_batch.reqs),
+ "reqs": running_reqs_info,
+ }
+
return GetLoadReqOutput(
dp_rank=self.dp_rank,
num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
num_waiting_reqs=num_waiting_reqs,
num_tokens=num_tokens,
ts_tic=time.perf_counter(),
+ queue_details=queue_details,
+ running_details=running_details,
)
@contextmanager
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index e40586c24..243e2b0c2 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
+
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
@@ -1070,7 +1071,7 @@ class SchedulerOutputProcessorMixin:
req.log_time_stats()
# Send to detokenizer
- if reqs or is_idle_batch:
+ if rids or is_idle_batch:
if self.model_config.is_multimodal_gen:
return
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index 293a84350..8ee36c794 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+import os
import traceback
from typing import TYPE_CHECKING, Tuple
@@ -12,6 +13,9 @@ from sglang.srt.constants import (
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.managers.io_struct import (
CheckWeightsReqInput,
CheckWeightsReqOutput,
@@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
@@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromIPCReqOutput(success, message)
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Optional post-processing for updated weights (e.g., Marlin conversion)."""
+ success, message = self.tp_worker.post_process_weights(recv_req)
+ return PostProcessWeightsReqOutput(success, message)
+
def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
@@ -137,6 +148,15 @@ class SchedulerUpdateWeightsMixin:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_transfer_queue"):
+ self.disagg_decode_transfer_queue.release_memory_occupation()
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.release_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.release_memory_occupation()
+
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.model_runner.model
@@ -177,6 +197,15 @@ class SchedulerUpdateWeightsMixin:
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_transfer_queue"):
+ self.disagg_decode_transfer_queue.resume_memory_occupation()
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.resume_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.resume_memory_occupation()
+
return ResumeMemoryOccupationReqOutput()
def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index e5d42bed8..412293b30 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput,
LoRAUpdateOutput,
OpenSessionReqInput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
@@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin:
self.update_weights_from_ipc_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.post_process_weights_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin:
UpdateWeightsFromIPCReqOutput,
self.update_weights_from_ipc_communicator.handle_recv,
),
+ (
+ PostProcessWeightsReqOutput,
+ self.post_process_weights_communicator.handle_recv,
+ ),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
@@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin:
return success, message
+ async def post_process_weights(
+ self: TokenizerManager,
+ obj: PostProcessWeightsReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion)."""
+ self.auto_create_handle_loop()
+ async with self.model_update_lock.writer_lock:
+ results = await self.post_process_weights_communicator(obj)
+ return _Communicator.merge_results(results)
+
async def init_weights_send_group_for_remote_instance(
self,
obj: InitWeightsSendGroupForRemoteInstanceReqInput,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 49f63a198..e4cd0ff2b 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
+ PostProcessWeightsReqInput,
SendWeightsToRemoteInstanceReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
@@ -175,6 +176,11 @@ class BaseTpWorker(ABC):
success, message = self.model_runner.update_weights_from_ipc(recv_req)
return success, message
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion)."""
+ success, message = self.model_runner.post_process_weights(recv_req)
+ return success, message
+
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py
index eaf29628b..bf74cbd12 100644
--- a/python/sglang/srt/mem_cache/allocator.py
+++ b/python/sglang/srt/mem_cache/allocator.py
@@ -287,6 +287,85 @@ def alloc_decode_kernel(
tl.store(out_indices + pid, page * page_size)
+def alloc_extend_torch_fallback(
+ prefix_lens_cpu: torch.Tensor,
+ seq_lens_cpu: torch.Tensor,
+ last_loc: torch.Tensor,
+ free_pages: torch.Tensor,
+ out_indices: torch.Tensor,
+ page_size: int,
+ debug_mode: bool = False,
+):
+ extend_lens_cpu = (seq_lens_cpu - prefix_lens_cpu).to(torch.int64)
+ if extend_lens_cpu.numel() == 0:
+ return
+
+ output_start_locs_cpu = torch.cumsum(extend_lens_cpu, dim=0) - extend_lens_cpu
+ num_pages_after = (seq_lens_cpu + page_size - 1) // page_size
+ num_pages_before = (prefix_lens_cpu + page_size - 1) // page_size
+ num_new_pages_cpu = num_pages_after - num_pages_before
+ page_start_locs_cpu = torch.cumsum(num_new_pages_cpu, dim=0) - num_new_pages_cpu
+
+ total_new_pages = int(num_new_pages_cpu.sum().item())
+ if total_new_pages > free_pages.numel():
+ return
+
+ if debug_mode:
+ assert int(extend_lens_cpu.sum().item()) == out_indices.numel()
+
+ prefix_lens_list = prefix_lens_cpu.tolist()
+ seq_lens_list = seq_lens_cpu.tolist()
+ extend_lens_list = extend_lens_cpu.tolist()
+ out_start_list = output_start_locs_cpu.tolist()
+ page_start_list = page_start_locs_cpu.tolist()
+ num_new_pages_list = num_new_pages_cpu.tolist()
+
+ device = out_indices.device
+ dtype = out_indices.dtype
+ offsets_page = torch.arange(page_size, device=device, dtype=dtype)
+
+ for i, extend_len in enumerate(extend_lens_list):
+ if extend_len == 0:
+ continue
+
+ pre_len = prefix_lens_list[i]
+ seq_len = seq_lens_list[i]
+ out_start = out_start_list[i]
+ page_start = page_start_list[i]
+ num_new_pages = num_new_pages_list[i]
+
+ pre_mod = pre_len % page_size
+ part1 = min(extend_len, page_size - pre_mod) if pre_mod != 0 else 0
+ if part1:
+ start_val = last_loc[i] + 1
+ out_indices[out_start : out_start + part1] = start_val + torch.arange(
+ part1, device=device, dtype=dtype
+ )
+ if part1 == extend_len:
+ continue
+
+ ceil_pre_pages = (pre_len + page_size - 1) // page_size
+ full_pages_after = seq_len // page_size
+ num_full_pages = full_pages_after - ceil_pre_pages
+ if num_full_pages < 0:
+ num_full_pages = 0
+ part2 = num_full_pages * page_size
+ if part2:
+ pages = free_pages[page_start : page_start + num_full_pages]
+ full_indices = (pages[:, None] * page_size + offsets_page).reshape(-1)
+ out_indices[out_start + part1 : out_start + part1 + part2] = full_indices
+ if part1 + part2 == extend_len:
+ continue
+
+ part3 = extend_len - part1 - part2
+ if part3:
+ last_page = free_pages[page_start + num_new_pages - 1]
+ out_indices[out_start + part1 + part2 : out_start + extend_len] = (
+ last_page * page_size
+ + torch.arange(part3, device=device, dtype=dtype)
+ )
+
+
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""
An allocator managing the indices to kv cache data.
@@ -349,11 +428,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
- self.seen_max_num_extend_tokens_next_power_of_2 = max(
- self.seen_max_num_extend_tokens_next_power_of_2,
- next_power_of_2(extend_num_tokens),
- )
-
bs = len(prefix_lens)
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
self.free_pages
@@ -363,16 +437,34 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
- alloc_extend_kernel[(bs,)](
- prefix_lens,
- seq_lens,
- last_loc,
- self.free_pages,
- out_indices,
- next_power_of_2(bs),
- self.page_size,
- self.seen_max_num_extend_tokens_next_power_of_2,
- )
+
+ # Use PyTorch fallback for large extend_num_tokens to avoid slow Triton compilation
+ MAX_TRITON_EXTEND_TOKENS = 65536 # 64K
+ if next_power_of_2(extend_num_tokens) > MAX_TRITON_EXTEND_TOKENS:
+ alloc_extend_torch_fallback(
+ prefix_lens_cpu=prefix_lens_cpu,
+ seq_lens_cpu=seq_lens_cpu,
+ last_loc=last_loc,
+ free_pages=self.free_pages,
+ out_indices=out_indices,
+ page_size=self.page_size,
+ debug_mode=self.debug_mode,
+ )
+ else:
+ self.seen_max_num_extend_tokens_next_power_of_2 = max(
+ self.seen_max_num_extend_tokens_next_power_of_2,
+ next_power_of_2(extend_num_tokens),
+ )
+ alloc_extend_kernel[(bs,)](
+ prefix_lens,
+ seq_lens,
+ last_loc,
+ self.free_pages,
+ out_indices,
+ next_power_of_2(bs),
+ self.page_size,
+ self.seen_max_num_extend_tokens_next_power_of_2,
+ )
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index f6cfca8b6..5d3cad059 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -11,10 +11,15 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+ MHATokenToKVPool,
+ MLATokenToKVPool,
+ NSATokenToKVPool,
+)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
+ NSATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import (
RadixCache,
@@ -54,6 +59,16 @@ class HiRadixCache(RadixCache):
server_args.hicache_mem_layout,
allocator_type=server_args.hicache_storage_backend,
)
+ elif isinstance(self.kv_cache, NSATokenToKVPool):
+ # Check NSA before MLA since NSATokenToKVPool is a subclass of MLATokenToKVPool
+ self.token_to_kv_pool_host = NSATokenToKVPoolHost(
+ self.kv_cache,
+ server_args.hicache_ratio,
+ server_args.hicache_size,
+ self.page_size,
+ server_args.hicache_mem_layout,
+ allocator_type=server_args.hicache_storage_backend,
+ )
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache,
@@ -64,7 +79,7 @@ class HiRadixCache(RadixCache):
allocator_type=server_args.hicache_storage_backend,
)
else:
- raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
+ raise ValueError(f"HiRadixCache only supports MHA and MLA and NSA yet")
self.tp_group = params.tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 65d562a27..fe5547d7b 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
- else nullcontext()
+ else nullcontext(),
+ self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE),
):
self.index_k_with_scale_buffer = [
torch.zeros(
@@ -1700,6 +1701,11 @@ class NSATokenToKVPool(MLATokenToKVPool):
)
for _ in range(layer_num)
]
+ self.index_k_with_scale_buffer_ptrs = torch.tensor(
+ [x.data_ptr() for x in self.index_k_with_scale_buffer],
+ dtype=torch.uint64,
+ device=self.device,
+ )
self._finalize_allocation_log(size)
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
@@ -1775,6 +1781,50 @@ class NSATokenToKVPool(MLATokenToKVPool):
]
return data_ptrs, data_lens, item_lens
+ def get_cpu_copy(self, indices):
+ # First, save the kv_buffer (inherited from MLATokenToKVPool)
+ kv_cache_cpu = super().get_cpu_copy(indices)
+
+ # Additionally, save the index_k_with_scale_buffer (page-indexed)
+ page_indices = indices[:: self.page_size] // self.page_size
+ torch.cuda.synchronize()
+ index_k_cpu = []
+ chunk_size = self.cpu_offloading_chunk_size
+ # Convert chunk_size from token-level to page-level
+ page_chunk_size = max(1, chunk_size // self.page_size)
+ for layer_id in range(self.layer_num):
+ index_k_cpu.append([])
+ for i in range(0, len(page_indices), page_chunk_size):
+ chunk_page_indices = page_indices[i : i + page_chunk_size]
+ idx_cpu = self.index_k_with_scale_buffer[layer_id][
+ chunk_page_indices
+ ].to("cpu", non_blocking=True)
+ index_k_cpu[-1].append(idx_cpu)
+ torch.cuda.synchronize()
+
+ return {"kv": kv_cache_cpu, "index_k": index_k_cpu}
+
+ def load_cpu_copy(self, kv_cache_cpu_dict, indices):
+ # Restore the kv_buffer (inherited from MLATokenToKVPool)
+ super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices)
+
+ # Restore the index_k_with_scale_buffer (page-indexed)
+ page_indices = indices[:: self.page_size] // self.page_size
+ index_k_cpu = kv_cache_cpu_dict["index_k"]
+ torch.cuda.synchronize()
+ chunk_size = self.cpu_offloading_chunk_size
+ page_chunk_size = max(1, chunk_size // self.page_size)
+ for layer_id in range(self.layer_num):
+ for i in range(0, len(page_indices), page_chunk_size):
+ chunk_page_indices = page_indices[i : i + page_chunk_size]
+ idx_cpu = index_k_cpu[layer_id][i // page_chunk_size]
+ assert idx_cpu.shape[0] == len(chunk_page_indices)
+ idx_chunk = idx_cpu.to(
+ self.index_k_with_scale_buffer[0].device, non_blocking=True
+ )
+ self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk
+ torch.cuda.synchronize()
+
def get_kv_size_bytes(self):
kv_size_bytes = super().get_kv_size_bytes()
for index_k_cache in self.index_k_with_scale_buffer:
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index 46394158f..d99cc4b3b 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -15,7 +15,12 @@ from sglang.jit_kernel.hicache import (
from sglang.jit_kernel.hicache import (
transfer_hicache_one_layer as jit_transfer_hicache_one_layer,
)
-from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+ KVCache,
+ MHATokenToKVPool,
+ MLATokenToKVPool,
+ NSATokenToKVPool,
+)
from sglang.srt.utils import is_cuda, is_npu, is_xpu
_is_cuda = is_cuda()
@@ -1015,3 +1020,199 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
+
+
+class NSATokenToKVPoolHost(MLATokenToKVPoolHost):
+ """
+ Host memory pool for NSA (Native Sparse Attention) KV cache.
+
+ NSA extends MLA with an additional index_k_with_scale_buffer that stores
+ sparse attention indexing information. This class ensures that buffer is
+ also backed up and restored during hicache operations.
+ """
+
+ device_pool: NSATokenToKVPool
+
+ def __init__(
+ self,
+ device_pool: NSATokenToKVPool,
+ host_to_device_ratio: float,
+ host_size: int,
+ page_size: int,
+ layout: str,
+ pin_memory: bool = True,
+ device: str = "cpu",
+ allocator_type: str = "default",
+ ):
+ # Store NSA-specific attributes before calling parent __init__
+ self.index_head_dim = device_pool.index_head_dim
+ self.quant_block_size = device_pool.quant_block_size
+ self.index_k_with_scale_buffer_dtype = device_pool.index_k_with_scale_buffer_dtype
+
+ super().__init__(
+ device_pool,
+ host_to_device_ratio,
+ host_size,
+ page_size,
+ layout,
+ pin_memory,
+ device,
+ allocator_type,
+ )
+
+ # Initialize index buffer references and pointers for efficient transfer
+ self.index_data_refs = [
+ self.index_k_with_scale_buffer[i] for i in range(self.layer_num)
+ ]
+ self.index_data_ptrs = torch.tensor(
+ [x.data_ptr() for x in self.index_data_refs],
+ dtype=torch.uint64,
+ device=self.device_pool.device,
+ )
+
+ def get_size_per_token(self):
+ # Get base MLA size
+ base_size = super().get_size_per_token()
+
+ # Add NSA index buffer size per token
+ # index_k_with_scale_buffer shape per layer: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4))
+ # Per token: (index_head_dim + index_head_dim // quant_block_size * 4) * dtype.itemsize * layer_num
+ index_size_per_token = (
+ (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4)
+ * self.index_k_with_scale_buffer_dtype.itemsize
+ * self.layer_num
+ )
+
+ return base_size + index_size_per_token
+
+ def init_kv_buffer(self):
+ # Initialize base MLA kv_buffer
+ buffer = super().init_kv_buffer()
+
+ # Initialize NSA index_k_with_scale_buffer on host
+ # Layout matches device pool: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4))
+ index_buffer_second_dim = self.page_size * (
+ self.index_head_dim + self.index_head_dim // self.quant_block_size * 4
+ )
+ self.index_stride_size = (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) * self.index_k_with_scale_buffer_dtype.itemsize
+
+ alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
+ self.index_k_with_scale_buffer = [
+ alloc_func(
+ (self.page_num, index_buffer_second_dim),
+ dtype=self.index_k_with_scale_buffer_dtype,
+ device=self.device,
+ pin_memory=self.pin_memory,
+ allocator=self.allocator,
+ )
+ for _ in range(self.layer_num)
+ ]
+
+ return buffer
+
+ def _load_indexer_to_device_per_layer(
+ self, device_pool, host_indices, device_indices, layer_id, io_backend
+ ):
+ """Load index_k_with_scale_buffer from host to device for a specific layer."""
+ # Convert token indices to page indices
+ # host_indices and device_indices are token-level indices
+ # index_k_with_scale_buffer is page-level with shape (num_pages, page_size * dim)
+
+ if io_backend == "kernel":
+ # Use page-level copy for index buffer
+ # Calculate page indices from token indices
+ page_indices_host = host_indices[:: self.page_size] // self.page_size
+ page_indices_device = device_indices[:: self.page_size] // self.page_size
+
+ src_buffer = self.index_k_with_scale_buffer[layer_id]
+ dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer]
+
+ # Copy each page
+ # for i in range(len(page_indices_host)):
+ # src_page_idx = page_indices_host[i].item()
+ # dst_page_idx = page_indices_device[i].item()
+ # dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True)
+ if self.layout == "layer_first":
+ transfer_kv_per_layer_mla(
+ src=src_buffer,
+ dst=dst_buffer,
+ src_indices=page_indices_host,
+ dst_indices=page_indices_device,
+ item_size=self.index_stride_size * self.page_size,
+ )
+ else:
+ raise ValueError(f"Unsupported layout: {self.layout}")
+
+ elif io_backend == "direct":
+ # Direct I/O copy for index buffer
+ page_indices_host = host_indices[:: self.page_size] // self.page_size
+ page_indices_device = device_indices[:: self.page_size] // self.page_size
+
+ src_buffer = self.index_k_with_scale_buffer[layer_id]
+ dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer]
+
+ for i in range(len(page_indices_host)):
+ src_page_idx = page_indices_host[i].item()
+ dst_page_idx = page_indices_device[i].item()
+ dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True)
+ else:
+ raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}")
+
+ def _backup_indexer_from_device_all_layer(
+ self, device_pool, host_indices, device_indices, io_backend
+ ):
+ """Backup index_k_with_scale_buffer from device to host for all layers."""
+ # Convert token indices to page indices
+ page_indices_host = host_indices[:: self.page_size] // self.page_size
+ page_indices_device = device_indices[:: self.page_size] // self.page_size
+
+ # if io_backend in ["kernel", "direct"]:
+ if io_backend == "kernel":
+ if self.layout == "layer_first":
+ transfer_kv_all_layer_mla(
+ src_layers=device_pool.index_k_with_scale_buffer_ptrs,
+ dst_layers=self.index_data_ptrs,
+ src_indices=page_indices_device,
+ dst_indices=page_indices_host,
+ item_size=self.index_stride_size * self.page_size,
+ num_layers=self.layer_num,
+ )
+ else:
+ raise ValueError(f"Unsupported layout: {self.layout}")
+ elif io_backend == "direct":
+ for layer_id in range(self.layer_num):
+ src_buffer = device_pool.index_k_with_scale_buffer[layer_id]
+ dst_buffer = self.index_k_with_scale_buffer[layer_id]
+
+ for i in range(len(page_indices_device)):
+ src_page_idx = page_indices_device[i].item()
+ dst_page_idx = page_indices_host[i].item()
+ dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True)
+ else:
+ raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}")
+
+ def load_to_device_per_layer(
+ self, device_pool, host_indices, device_indices, layer_id, io_backend
+ ):
+ """Load KV cache and index buffer from host to device for a specific layer."""
+ # Load base MLA kv_buffer
+ super().load_to_device_per_layer(
+ device_pool, host_indices, device_indices, layer_id, io_backend
+ )
+ # Load NSA index_k_with_scale_buffer
+ self._load_indexer_to_device_per_layer(
+ device_pool, host_indices, device_indices, layer_id, io_backend
+ )
+
+ def backup_from_device_all_layer(
+ self, device_pool, host_indices, device_indices, io_backend
+ ):
+ """Backup KV cache and index buffer from device to host for all layers."""
+ # Backup base MLA kv_buffer
+ super().backup_from_device_all_layer(
+ device_pool, host_indices, device_indices, io_backend
+ )
+ # Backup NSA index_k_with_scale_buffer
+ self._backup_indexer_from_device_all_layer(
+ device_pool, host_indices, device_indices, io_backend
+ )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 1d69c0582..d984c2e12 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin):
)
# Init routed experts capturer
- self.init_routed_experts_capturer()
+ if not self.is_draft_worker:
+ self.init_routed_experts_capturer()
if self.device == "cuda":
self.init_cublas()
@@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin):
output.expert_distribution_metrics = recorder_outputs.get("metrics")
# Copy cached routing experts' buffers back to CPU cache
- get_global_experts_capturer().on_forward_end(
- forward_batch=forward_batch,
- can_run_graph=output.can_run_graph,
- cuda_graph_batch=getattr(self.graph_runner, "bs", None),
- )
+ if not self.is_draft_worker:
+ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass
+ # the actual number of tokens per dp rank in cuda graph, not batch size.
+ cuda_graph_num_tokens = None
+ if getattr(self.graph_runner, "bs", None):
+ cuda_graph_num_tokens = (
+ self.graph_runner.bs * self.graph_runner.num_tokens_per_bs
+ )
+ get_global_experts_capturer().on_forward_end(
+ forward_batch=forward_batch,
+ can_run_graph=output.can_run_graph,
+ cuda_graph_batch=cuda_graph_num_tokens,
+ )
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end()
@@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin):
logger.error(f"IPC weight update failed: {e}")
return False, str(e)
+ def post_process_weights(self, recv_req):
+ """
+ Execute post-processing logic for model weights, such as Marlin quantization format conversion.
+ """
+ from sglang.srt.model_loader.loader import device_loading_context
+
+ target_device = torch.device("cuda", torch.cuda.current_device())
+
+ if recv_req.restore_weights_before_load:
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+
+ # Check if the module supports restoring weights
+ if quant_method is not None and hasattr(
+ quant_method, "restore_weights_before_loading"
+ ):
+
+ with device_loading_context(module, target_device):
+ quant_method.restore_weights_before_loading(module)
+
+ if recv_req.post_process_quantization:
+ # Iterate through all modules to apply specific post-loading processing
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+
+ # Check if the module supports quantization post-processing
+ if quant_method is not None and hasattr(
+ quant_method, "process_weights_after_loading"
+ ):
+
+ # Apply the post-processing (e.g., repacking weights for Marlin kernel)
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+
+ return True, "Success"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index ed8cc7ada..b8f1026dd 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -159,6 +159,7 @@ from sglang.srt.utils import (
make_layers,
use_intel_amx_backend,
)
+from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend
_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -434,6 +435,8 @@ def handle_attention_nsa(attn, forward_batch):
backend = forward_batch.attn_backend
if isinstance(backend, TboAttnBackend): # if enable tbo, get primary backend
backend = backend.primary
+ if isinstance(backend, HybridAttnBackend):
+ backend = backend._select_backend(forward_batch.forward_mode)
if hasattr(backend, "use_mha") and backend.use_mha:
return AttnForwardMethod.MHA_ONE_SHOT
return AttnForwardMethod.MLA
@@ -2704,7 +2707,11 @@ class DeepseekV2AttentionMLA(nn.Module):
):
k = k_nope.new_empty(*k_shape)
concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
- elif _is_cuda:
+ elif _is_cuda and all(
+ # (i.bit_count() == 1) == (is_power_of_two(i))
+ i.bit_count() == 1
+ for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1])
+ ):
# fa3 mha support fp8 inputs
if (
self.current_attention_backend == "fa3"
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index a7dbadec6..c83a41338 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x):
- if get_global_server_args().rl_on_policy_target is not None:
- x = x.bfloat16()
-
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module):
quant_config=quant_config,
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
- params_dtype=(
- torch.float32
- if get_global_server_args().rl_on_policy_target is not None
- else None
- ),
)
else:
self.embed_tokens = PPMissingLayer()
@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module):
if self.pp_group.is_last_rank:
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
- override_orig_dtype=torch.float32,
- fp32_residual=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 3ad9f6736..0b9c7f499 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module):
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+ )
else:
self.norm = PPMissingLayer(return_tuple=True)
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 9220831f6..2b8303b54 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module):
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module):
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
- override_orig_dtype=torch.float32,
- fp32_residual=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
@@ -276,14 +274,14 @@ class Qwen3DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
- **kwargs,
+ post_residual_addition: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states,
residual,
forward_batch,
- **kwargs,
+ post_residual_addition=post_residual_addition,
)
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index e11678a9e..e277d46f2 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -22,6 +22,7 @@ import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar
import torch
+import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_grouped_topk=False,
layer_id=layer_id,
)
+ self.top_k = config.num_experts_per_tok
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
- topk_output = self.topk(hidden_states, router_logits)
+
+ if get_global_server_args().rl_on_policy_target is not None:
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(
+ routing_weights, self.top_k, dim=-1
+ )
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
+ topk_output = StandardTopKOutput(
+ topk_weights=routing_weights,
+ topk_ids=selected_experts,
+ router_logits=router_logits,
+ )
+ else:
+ topk_output = self.topk(hidden_states, router_logits)
+
final_hidden_states = self.experts(hidden_states, topk_output)
if (
self.tp_size > 1
@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module):
)
self.compatible_with_fused_kv_buffer = (
False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
- )
+ ) and (get_global_server_args().rl_on_policy_target is None)
self.compatible_with_fused_qk_norm_rope = (
not isinstance(self.rotary_emb, MRotaryEmbedding)
) and self.head_dim in (64, 128, 256)
self.use_fused_qk_norm_rope = (
get_global_server_args().enable_fused_qk_norm_rope
and self.compatible_with_fused_qk_norm_rope
+ and (get_global_server_args().rl_on_policy_target is None)
)
self._used_fused_qk_norm_rope_last_call = False
@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module):
prefix=add_prefix("attn", prefix),
)
- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
self.alt_stream = alt_stream
def op_prepare(self, state):
@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.input_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+ )
self.post_attention_layernorm = RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
)
self.layer_communicator = LayerCommunicator(
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index 079f45843..218e32362 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin):
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw):
- patch_pos_embeds_permute = []
- m_size = self.spatial_merge_size
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+ num_grid_per_side = int(self.num_position_embeddings**0.5)
+ device = self.pos_embed.weight.device
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * num_grid_per_side
+ base_h_ceil = h_idxs_ceil * num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
- embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device)
- embeds = (
- self.pos_embed(embeds)
- .permute(1, 0)
- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side)
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=device
)
- for t, h, w in grid_thw:
- pos_embed = torch.nn.functional.interpolate(
- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners
- )
- pos_embed = pos_embed.reshape(
- -1,
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
+ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split(
+ [h * w for h, w in zip(grid_hs, grid_ws)]
+ )
+
+ patch_pos_embeds_permute = []
+ merge_size = self.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
)
- pos_embed = pos_embed.permute(1, 3, 2, 4, 0)
- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
@@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model):
hidden_states + residual if residual is not None else hidden_states
)
+ deepstack_embeds = None
+ if input_deepstack_embeds is not None:
+ prev_layer_idx = layer_idx - 1
+ if prev_layer_idx in self.deepstack_embed_to_decoder_layer:
+ sep = self.hidden_size * prev_layer_idx
+ deepstack_embeds = input_deepstack_embeds[
+ :, sep : sep + self.hidden_size
+ ]
+
# SGLang applies residual at the START of the next layer, not at the END like HuggingFace.
# See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549
# To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack
# The order matters because addition with different tensors is not associative in practice.
- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
- deepstack_embeds = self.get_deepstack_embeds(
- layer_idx - 1, input_deepstack_embeds
- )
hidden_states, residual = layer(
positions,
hidden_states,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a2b26e0e0..72db29801 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -527,6 +527,7 @@ class ServerArgs:
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
disable_cuda_graph: bool = False
+ disable_draft_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
@@ -3980,6 +3981,11 @@ class ServerArgs:
action="store_true",
help="Disable cuda graph.",
)
+ parser.add_argument(
+ "--disable-draft-cuda-graph",
+ action="store_true",
+ help="Disable cuda graph for draft model in speculative decoding.",
+ )
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 5fe45086c..c95fbd0f6 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner:
self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_()
self.positions.zero_()
-
+ self.topk_p.zero_()
+ self.topk_index.zero_()
+ self.hidden_states.zero_()
+ self.req_pool_indices.zero_()
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1))
+ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1))
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 1bf3816e9..b5b41dba4 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
+ if self.accept_length is not None:
+ self.accept_length = self.accept_length[: len(new_indices)]
+ if self.accept_length_cpu is not None:
+ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+ if self.accept_length is not None and spec_info.accept_length is not None:
+ self.accept_length = torch.cat(
+ [self.accept_length, spec_info.accept_length]
+ )
+ self.accept_length_cpu = self.accept_length.tolist()
+ elif self.accept_length is not None:
+ zeros = torch.zeros(
+ [spec_info.verified_id.shape[0]],
+ dtype=self.accept_length.dtype,
+ device=self.accept_length.device,
+ )
+ self.accept_length = torch.cat([self.accept_length, zeros])
+ self.accept_length_cpu = self.accept_length.tolist()
+ elif spec_info.accept_length is not None:
+ zeros = torch.zeros(
+ [self.verified_id.shape[0]],
+ dtype=self.accept_length.dtype,
+ device=self.accept_length.device,
+ )
+ self.accept_length = torch.cat([zeros, spec_info.accept_length])
+ self.accept_length_cpu = self.accept_length.tolist()
@dataclass
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index a702df4f8..61d9ae366 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
- if self.server_args.disable_cuda_graph:
+ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph:
return
Device2DraftCudaGraphRunner = {
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 8560246c6..13db860dc 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler):
"sglang.srt.model_executor.model_runner.",
"sglang.srt.layers.",
"sglang.srt.utils.",
+ # --- slime ---
+ "slime.",
}
DENY_CLASSES = {