| |
| |
| |
| |
| @@ -268,6 +268,12 @@ class ModelConfig: |
| ): |
| self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" |
| |
| + if ( |
| + is_draft_model |
| + and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM" |
| + ): |
| + self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" |
| + |
| if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": |
| self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" |
| |
| |
| |
| |
| |
| @@ -21,6 +21,7 @@ Life cycle of a request in the decode server |
| from __future__ import annotations |
| |
| import logging |
| +import os |
| import time |
| from collections import deque |
| from dataclasses import dataclass |
| @@ -315,6 +316,16 @@ class DecodePreallocQueue: |
| ) |
| return kv_manager |
| |
| + def release_memory_occupation(self): |
| + self.queue.clear() |
| + self.retracted_queue.clear() |
| + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): |
| + self.kv_manager.deregister_buffer_to_engine() |
| + |
| + def resume_memory_occupation(self): |
| + if hasattr(self.kv_manager, "register_buffer_to_engine"): |
| + self.kv_manager.register_buffer_to_engine() |
| + |
| def add(self, req: Req, is_retracted: bool = False) -> None: |
| """Add a request to the pending queue.""" |
| if self._check_if_req_exceed_kv_capacity(req): |
| @@ -419,12 +430,37 @@ class DecodePreallocQueue: |
| [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group |
| ) |
| |
| + # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. |
| + bootstrap_timeout = float( |
| + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") |
| + ) |
| + now = time.perf_counter() |
| + |
| for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): |
| if rids_to_check is not None and decode_req.req.rid not in rids_to_check: |
| continue |
| |
| if poll == KVPoll.Bootstrapping: |
| - pass |
| + # Check for bootstrap timeout |
| + entry_time = getattr( |
| + decode_req.req.time_stats, |
| + "decode_prealloc_queue_entry_time", |
| + None, |
| + ) |
| + if entry_time is not None and (now - entry_time) > bootstrap_timeout: |
| + error_message = ( |
| + f"Decode bootstrap timed out after {now - entry_time:.1f}s " |
| + f"for request rank={self.tp_rank} " |
| + f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" |
| + ) |
| + logger.error(error_message) |
| + prepare_abort( |
| + decode_req.req, |
| + error_message, |
| + status_code=HTTPStatus.GATEWAY_TIMEOUT, |
| + ) |
| + if self.scheduler.enable_metrics: |
| + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() |
| elif poll == KVPoll.WaitingForInput: |
| decode_req.waiting_for_input = True |
| elif poll == KVPoll.Failed: |
| @@ -776,6 +812,13 @@ class DecodeTransferQueue: |
| [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group |
| ) |
| |
| + # Transfer timeout: if a request has been in the transfer queue for too long |
| + # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed. |
| + transfer_timeout = float( |
| + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") |
| + ) |
| + now = time.perf_counter() |
| + |
| transferred_reqs = [] |
| indices_to_remove = set() |
| for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): |
| @@ -811,7 +854,33 @@ class DecodeTransferQueue: |
| KVPoll.WaitingForInput, |
| KVPoll.Transferring, |
| ]: |
| - pass |
| + # Check for transfer timeout |
| + entry_time = getattr( |
| + decode_req.req.time_stats, |
| + "decode_transfer_queue_entry_time", |
| + None, |
| + ) |
| + if entry_time is not None and (now - entry_time) > transfer_timeout: |
| + error_message = ( |
| + f"Decode transfer timed out after {now - entry_time:.1f}s " |
| + f"(state={poll}) for request rank={self.tp_rank} " |
| + f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" |
| + ) |
| + logger.error(error_message) |
| + prepare_abort( |
| + decode_req.req, |
| + error_message, |
| + status_code=HTTPStatus.GATEWAY_TIMEOUT, |
| + ) |
| + self.scheduler.stream_output( |
| + [decode_req.req], decode_req.req.return_logprob |
| + ) |
| + release_kv_cache( |
| + decode_req.req, self.tree_cache, is_insert=False |
| + ) |
| + indices_to_remove.add(i) |
| + if self.scheduler.enable_metrics: |
| + self.scheduler.metrics_collector.increment_transfer_failed_reqs() |
| else: |
| raise ValueError(f"Unexpected poll case: {poll}") |
| |
| @@ -827,6 +896,14 @@ class DecodeTransferQueue: |
| |
| return transferred_reqs |
| |
| + def release_memory_occupation(self): |
| + """Clean up all in-flight transfers before releasing GPU memory.""" |
| + self.queue.clear() |
| + |
| + def resume_memory_occupation(self): |
| + """Resume after GPU memory re-allocation. Queue was already cleared on release.""" |
| + pass |
| + |
| |
| class SchedulerDisaggregationDecodeMixin: |
| |
| @@ -1004,7 +1081,15 @@ class SchedulerDisaggregationDecodeMixin: |
| resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() |
| self.waiting_queue.extend(resumed_reqs) |
| if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: |
| - # if there are still retracted requests, we do not allocate new requests |
| + # Still have retracted requests that couldn't resume (not enough memory). |
| + # Don't accept new requests (pop_preallocated) — they would consume memory |
| + # that retracted requests need. |
| + # But DO drain completed transfers: their KV is already committed, and |
| + # moving them to waiting_queue frees the reserved-decode-token budget |
| + # in _allocatable_tokens(), which may unblock resume on the next iteration. |
| + # Without this, completed transfers hold memory indefinitely → deadlock. |
| + alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred() |
| + self.waiting_queue.extend(alloc_reqs) |
| return |
| |
| if not hasattr(self, "polling_count"): |
| |
| |
| |
| |
| @@ -253,6 +253,19 @@ class MooncakeKVManager(CommonKVManager): |
| self.kv_args.state_data_ptrs, self.kv_args.state_data_lens |
| ) |
| |
| + def deregister_buffer_to_engine(self): |
| + # Batch deregister KV data buffers |
| + if self.kv_args.kv_data_ptrs: |
| + self.engine.batch_deregister(self.kv_args.kv_data_ptrs) |
| + |
| + # Batch deregister auxiliary data buffers |
| + if self.kv_args.aux_data_ptrs: |
| + self.engine.batch_deregister(self.kv_args.aux_data_ptrs) |
| + |
| + # Batch deregister state/extra pool data buffers |
| + if self.kv_args.state_data_ptrs: |
| + self.engine.batch_deregister(self.kv_args.state_data_ptrs) |
| + |
| def _transfer_data(self, mooncake_session_id, transfer_blocks): |
| if not transfer_blocks: |
| return 0 |
| |
| |
| |
| |
| @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server |
| from __future__ import annotations |
| |
| import logging |
| +import os |
| import time |
| from collections import deque |
| from http import HTTPStatus |
| @@ -250,6 +251,12 @@ class PrefillBootstrapQueue: |
| [req.disagg_kv_sender for req in self.queue], self.gloo_group |
| ) |
| |
| + # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. |
| + bootstrap_timeout = float( |
| + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") |
| + ) |
| + now = time.perf_counter() |
| + |
| for i, (req, poll) in enumerate(zip(self.queue, polls)): |
| if rids_to_check is not None: |
| # if req not in reqs_info_to_check, skip |
| @@ -257,6 +264,27 @@ class PrefillBootstrapQueue: |
| continue |
| |
| if poll == KVPoll.Bootstrapping: |
| + # Check for bootstrap timeout |
| + entry_time = getattr( |
| + req.time_stats, |
| + "prefill_bootstrap_queue_entry_time", |
| + None, |
| + ) |
| + if entry_time is not None and (now - entry_time) > bootstrap_timeout: |
| + error_message = ( |
| + f"Prefill bootstrap timed out after {now - entry_time:.1f}s " |
| + f"for request rank={self.tp_rank} " |
| + f"{req.rid=} {req.bootstrap_room=}" |
| + ) |
| + logger.error(error_message) |
| + prepare_abort( |
| + req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT |
| + ) |
| + self.scheduler.stream_output([req], req.return_logprob) |
| + indices_to_remove.add(i) |
| + failed_reqs.append(req) |
| + if self.scheduler.enable_metrics: |
| + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() |
| continue |
| elif poll == KVPoll.Failed: |
| error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" |
| @@ -306,6 +334,15 @@ class PrefillBootstrapQueue: |
| else: |
| return bootstrapped_reqs, failed_reqs |
| |
| + def release_memory_occupation(self): |
| + self.queue.clear() |
| + if hasattr(self.kv_manager, "deregister_buffer_to_engine"): |
| + self.kv_manager.deregister_buffer_to_engine() |
| + |
| + def resume_memory_occupation(self): |
| + if hasattr(self.kv_manager, "register_buffer_to_engine"): |
| + self.kv_manager.register_buffer_to_engine() |
| + |
| |
| class SchedulerDisaggregationPrefillMixin: |
| """ |
| @@ -535,6 +572,13 @@ class SchedulerDisaggregationPrefillMixin: |
| self.attn_tp_cpu_group, |
| ) |
| |
| + # Transfer timeout: if a request has been in the inflight queue for too long |
| + # (e.g., stuck in WaitingForInput/Transferring), treat it as failed. |
| + transfer_timeout = float( |
| + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") |
| + ) |
| + now = time.perf_counter() |
| + |
| undone_reqs: List[Req] = [] |
| # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue |
| for req, poll in zip(self.disagg_prefill_inflight_queue, polls): |
| @@ -547,7 +591,30 @@ class SchedulerDisaggregationPrefillMixin: |
| assert poll == KVPoll.Success or poll == KVPoll.Failed |
| |
| if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: |
| - undone_reqs.append(req) |
| + # Check for transfer timeout |
| + entry_time = getattr( |
| + req.time_stats, |
| + "prefill_transfer_queue_entry_time", |
| + None, |
| + ) |
| + if entry_time is not None and (now - entry_time) > transfer_timeout: |
| + error_message = ( |
| + f"Prefill transfer timed out after {now - entry_time:.1f}s " |
| + f"(state={poll}) for request rank={self.tp_rank} " |
| + f"{req.rid=} {req.bootstrap_room=}" |
| + ) |
| + logger.error(error_message) |
| + release_kv_cache(req, self.tree_cache) # unlock the tree |
| + prepare_abort( |
| + req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT |
| + ) |
| + if hasattr(req.disagg_kv_sender, "clear"): |
| + req.disagg_kv_sender.clear() |
| + done_reqs.append(req) |
| + if self.enable_metrics: |
| + self.metrics_collector.increment_transfer_failed_reqs() |
| + else: |
| + undone_reqs.append(req) |
| elif poll == KVPoll.Success: # transfer done |
| release_kv_cache(req, self.tree_cache) # unlock the tree |
| req.finished_reason = FINISH_LENGTH(length=0) |
| |
| |
| |
| |
| @@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size(): |
| |
| def get_tensor_model_parallel_rank(): |
| """Return my rank for the tensor model parallel group.""" |
| - return get_tp_group().rank_in_group |
| + try: |
| + return get_tp_group().rank_in_group |
| + except Exception: |
| + return 0 |
| |
| |
| def get_pipeline_model_parallel_world_size(): |
| |
| |
| |
| |
| @@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import ( |
| InitWeightsUpdateGroupReqInput, |
| LoadLoRAAdapterReqInput, |
| MultimodalDataInputFormat, |
| + PostProcessWeightsReqInput, |
| ReleaseMemoryOccupationReqInput, |
| ResumeMemoryOccupationReqInput, |
| RpcReqInput, |
| @@ -593,6 +594,24 @@ class Engine(EngineBase): |
| self.tokenizer_manager.update_weights_from_ipc(obj, None) |
| ) |
| |
| + def post_process_weights( |
| + self, |
| + restore_weights_before_load: bool = False, |
| + post_process_quantization: bool = False, |
| + ): |
| + """ |
| + Optional post-processing for updated weights (e.g., Marlin conversion). |
| + Should be called after weight update is finished. |
| + """ |
| + obj = PostProcessWeightsReqInput( |
| + restore_weights_before_load=restore_weights_before_load, |
| + post_process_quantization=post_process_quantization, |
| + ) |
| + |
| + return self.loop.run_until_complete( |
| + self.tokenizer_manager.post_process_weights(obj, None) |
| + ) |
| + |
| def get_weights_by_name(self, name: str, truncate_size: int = 100): |
| """Get weights by parameter name.""" |
| obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) |
| |
| |
| |
| |
| @@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import ( |
| OpenSessionReqInput, |
| ParseFunctionCallReq, |
| PauseGenerationReqInput, |
| + PostProcessWeightsReqInput, |
| ProfileReqInput, |
| ReleaseMemoryOccupationReqInput, |
| ResumeMemoryOccupationReqInput, |
| @@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re |
| else: |
| return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) |
| |
| +@app.post("/post_process_weights") |
| +async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): |
| + """ |
| + Optional post-processing for updated weights (e.g., Marlin conversion). |
| + This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. |
| + """ |
| + success, message = await _global_state.tokenizer_manager.post_process_weights( |
| + req, request |
| + ) |
| + |
| + content = {"success": success, "message": message} |
| + return ORJSONResponse( |
| + content, status_code=200 if success else HTTPStatus.BAD_REQUEST |
| + ) |
| + |
| |
| @app.post("/update_weight_version") |
| async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): |
| |
| |
| |
| |
| @@ -613,7 +613,6 @@ def _get_k_and_s_triton( |
| page_indices, |
| k_out, |
| s_out, |
| - seq_len, |
| page_size, |
| buf_numel_per_page, |
| index_head_dim, |
| @@ -630,7 +629,6 @@ def _get_k_and_s_triton_kernel( |
| page_indices_ptr, |
| k_out_ptr, |
| s_out_ptr, |
| - seq_len: tl.constexpr, |
| page_size: tl.constexpr, |
| buf_numel_per_page: tl.constexpr, |
| index_head_dim: tl.constexpr, |
| |
| |
| |
| |
| @@ -3,6 +3,7 @@ from __future__ import annotations |
| from abc import ABC, abstractmethod |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
| |
| +import os |
| import torch |
| from einops import rearrange |
| |
| @@ -178,7 +179,7 @@ class Indexer(MultiPlatformOp): |
| max_position=max_position_embeddings, |
| base=rope_theta, # type: ignore |
| rope_scaling=rope_scaling, |
| - is_neox_style=True, |
| + is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False, |
| device=get_global_server_args().device, |
| ) |
| self.block_size = block_size |
| @@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp): |
| @torch.compile(dynamic=True) |
| def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): |
| weights, _ = self.weights_proj(x.float()) |
| + if weights.shape[1] < 32: |
| + assert 32 % weights.shape[1] == 0 |
| + weights = weights.repeat_interleave(32 // weights.shape[1], dim=1) |
| weights = weights * self.n_heads**-0.5 |
| weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale |
| return weights |
| @@ -837,6 +841,9 @@ class Indexer(MultiPlatformOp): |
| query, key = self._get_q_k_bf16( |
| q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch |
| ) |
| + if query.shape[1] < 32: |
| + assert 32 % query.shape[1] == 0 |
| + query = query.repeat_interleave(32//query.shape[1], dim=1) |
| |
| if enable_dual_stream: |
| current_stream = torch.cuda.current_stream() |
| |
| |
| |
| |
| @@ -371,10 +371,13 @@ class LayerCommunicator: |
| residual: torch.Tensor, |
| forward_batch: ForwardBatch, |
| captured_last_layer_outputs: Optional[List[torch.Tensor]] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ): |
| hidden_states, residual = self.prepare_attn( |
| - hidden_states, residual, forward_batch, **kwargs |
| + hidden_states, |
| + residual, |
| + forward_batch, |
| + post_residual_addition=post_residual_addition, |
| ) |
| if captured_last_layer_outputs is not None: |
| gathered_last_layer_output = self._communicate_simple_fn( |
| @@ -394,7 +397,7 @@ class LayerCommunicator: |
| residual: torch.Tensor, |
| forward_batch: ForwardBatch, |
| quant_format: str = "", |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ): |
| if get_attn_tp_context().input_scattered: |
| hidden_states, residual = self._tp_reduce_scatter( |
| @@ -444,7 +447,7 @@ class LayerCommunicator: |
| ) |
| |
| else: |
| - hidden_states = self.input_layernorm(hidden_states, **kwargs) |
| + hidden_states = self.input_layernorm(hidden_states) |
| else: |
| |
| if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format): |
| @@ -478,7 +481,7 @@ class LayerCommunicator: |
| hidden_states, residual = self.input_layernorm( |
| hidden_states, |
| residual, |
| - **kwargs, |
| + post_residual_addition, |
| ) |
| |
| hidden_states = self._communicate_simple_fn( |
| |
| |
| |
| |
| @@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp): |
| eps: float = 1e-6, |
| var_hidden_size: Optional[int] = None, |
| cast_x_before_out_mul: bool = False, |
| - fp32_residual: bool = False, |
| - weight_dtype: Optional = None, |
| - override_orig_dtype: Optional = None, |
| + fp32_residual: bool = True, |
| ) -> None: |
| super().__init__() |
| self.cast_x_before_out_mul = cast_x_before_out_mul |
| self.fp32_residual = fp32_residual |
| - self.override_orig_dtype = override_orig_dtype |
| - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) |
| + self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
| self.hidden_size = hidden_size |
| self.variance_size_override = ( |
| @@ -104,16 +101,16 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if self.variance_size_override is not None: |
| - return self.forward_native(x, residual, **kwargs) |
| + return self.forward_native(x, residual, post_residual_addition) |
| if is_batch_invariant_mode_enabled(): |
| if ( |
| residual is not None |
| or get_global_server_args().rl_on_policy_target == "fsdp" |
| ): |
| - return self.forward_native(x, residual, **kwargs) |
| + return self.forward_native(x, residual, post_residual_addition) |
| return rms_norm_batch_invariant( |
| x, |
| self.weight.data, |
| @@ -124,7 +121,6 @@ class RMSNorm(MultiPlatformOp): |
| # but right now we can only have hidden_states+(residual+post_residual_addition). |
| # (hidden_states+residual)+post_residual_addition != hidden_states+(residual+post_residual_addition), |
| # we probably need to add another parameter to fused_add_rmsnorm |
| - post_residual_addition = kwargs.get("post_residual_addition") |
| if post_residual_addition is not None: |
| residual = residual + post_residual_addition |
| fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) |
| @@ -136,9 +132,11 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| out, _, residual_out = torch_npu.npu_add_rms_norm( |
| residual, x, self.weight.data, self.variance_epsilon |
| ) |
| @@ -149,9 +147,11 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| residual_out = torch.empty_like(x) |
| output = torch.empty_like(x) |
| fused_add_rms_norm( |
| @@ -169,12 +169,14 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if not x.is_contiguous(): |
| # NOTE: Remove this if aiter kernel supports discontinuous input |
| x = x.contiguous() |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| out = torch.empty_like(x) |
| residual_out = torch.empty_like(x) |
| fused_add_rms_norm( |
| @@ -189,27 +191,23 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if not x.is_contiguous(): |
| x = x.contiguous() |
| - orig_dtype = self.override_orig_dtype or x.dtype |
| - post_residual_addition = kwargs.get("post_residual_addition") |
| + orig_dtype = x.dtype |
| + |
| + if residual is not None and not self.fp32_residual: |
| + x = x + residual |
| + if post_residual_addition is not None: |
| + x = x + post_residual_addition |
| + residual = x.clone() |
| x = x.to(torch.float32) |
| - if residual is not None: |
| - x = ( |
| - x |
| - + residual.to(torch.float32) |
| - + ( |
| - post_residual_addition.to(torch.float32) |
| - if post_residual_addition is not None |
| - else 0.0 |
| - ) |
| - ) |
| - if self.fp32_residual: |
| - residual = x.clone() |
| - else: |
| - residual = x.to(orig_dtype) |
| + if residual is not None and self.fp32_residual: |
| + x = x + residual.to(torch.float32) |
| + if post_residual_addition is not None: |
| + x = x + post_residual_addition.to(torch.float32) |
| + residual = x.to(orig_dtype) |
| |
| hidden_size = x.shape[-1] |
| if hidden_size != self.hidden_size: |
| @@ -246,10 +244,12 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if _is_cpu_amx_available: |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( |
| x, residual, self.weight.data, self.variance_epsilon |
| ) |
| @@ -258,17 +258,19 @@ class RMSNorm(MultiPlatformOp): |
| x, self.weight.data, self.variance_epsilon |
| ) |
| else: |
| - return self.forward_native(x, residual, **kwargs) |
| + return self.forward_native(x, residual, post_residual_addition) |
| |
| def forward_xpu( |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if self.variance_size_override is not None: |
| - return self.forward_native(x, residual, **kwargs) |
| + return self.forward_native(x, residual, post_residual_addition) |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) |
| return x, residual |
| out = rmsnorm(x, self.weight.data, self.variance_epsilon) |
| @@ -278,6 +280,7 @@ class RMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| """ |
| Forward method with allreduce fusion, prioritizing flashinfer fused operations |
| @@ -289,6 +292,8 @@ class RMSNorm(MultiPlatformOp): |
| ) |
| |
| if get_tensor_model_parallel_world_size() > 1: |
| + if post_residual_addition is not None: |
| + x = x + post_residual_addition |
| fused_result = flashinfer_allreduce_residual_rmsnorm( |
| input_tensor=x, |
| residual=residual, |
| @@ -298,7 +303,7 @@ class RMSNorm(MultiPlatformOp): |
| if fused_result[0] is not None: |
| return fused_result |
| |
| - return self.forward(x, residual) |
| + return self.forward(x, residual, post_residual_addition) |
| |
| |
| class LayerNorm(MultiPlatformOp): |
| @@ -323,7 +328,6 @@ class LayerNorm(MultiPlatformOp): |
| def forward_cuda( |
| self, |
| x: torch.Tensor, |
| - **kwargs, |
| ) -> torch.Tensor: |
| if ( |
| _flashinfer_layernorm_available |
| @@ -332,12 +336,11 @@ class LayerNorm(MultiPlatformOp): |
| ): |
| return layernorm(x, self.weight, self.bias, self.variance_epsilon) |
| else: |
| - return self.forward_native(x, **kwargs) |
| + return self.forward_native(x) |
| |
| def forward_native( |
| self, |
| x: torch.Tensor, |
| - **kwargs, |
| ) -> torch.Tensor: |
| weight = self.weight if self.elementwise_affine else None |
| bias = self.bias if self.use_bias else None |
| @@ -354,28 +357,25 @@ class LayerNorm(MultiPlatformOp): |
| def forward_hip( |
| self, |
| x: torch.Tensor, |
| - **kwargs, |
| ) -> torch.Tensor: |
| - return self.forward_native(x, **kwargs) |
| + return self.forward_native(x) |
| |
| def forward_npu( |
| self, |
| x: torch.Tensor, |
| - **kwargs, |
| ) -> torch.Tensor: |
| - return self.forward_native(x, **kwargs) |
| + return self.forward_native(x) |
| |
| def forward_cpu( |
| self, |
| x: torch.Tensor, |
| - **kwargs, |
| ) -> torch.Tensor: |
| if _is_cpu_amx_available: |
| return torch.ops.sgl_kernel.layernorm_cpu( |
| x, self.weight.data, self.variance_epsilon |
| ) |
| else: |
| - return self.forward_native(x, **kwargs) |
| + return self.forward_native(x) |
| |
| |
| class GemmaRMSNorm(MultiPlatformOp): |
| @@ -396,9 +396,11 @@ class GemmaRMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| gemma_fused_add_rmsnorm( |
| x, residual, self.weight.data, self.variance_epsilon |
| ) |
| @@ -410,11 +412,13 @@ class GemmaRMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| orig_dtype = x.dtype |
| if residual is not None: |
| x = x + residual |
| + if post_residual_addition is not None: |
| + x = x + post_residual_addition |
| residual = x |
| |
| x = x.float() |
| @@ -428,18 +432,20 @@ class GemmaRMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| - return self._forward_impl(x, residual, **kwargs) |
| + return self._forward_impl(x, residual, post_residual_addition) |
| |
| def forward_cpu( |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if _is_cpu_amx_available: |
| if residual is not None: |
| + if post_residual_addition is not None: |
| + residual = residual + post_residual_addition |
| torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( |
| x, residual, self.weight.data, self.variance_epsilon |
| ) |
| @@ -447,16 +453,18 @@ class GemmaRMSNorm(MultiPlatformOp): |
| return torch.ops.sgl_kernel.gemma_rmsnorm_cpu( |
| x, self.weight.data, self.variance_epsilon |
| ) |
| - return self.forward_native(x, residual, **kwargs) |
| + return self.forward_native(x, residual, post_residual_addition) |
| |
| def forward_npu( |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| if residual is not None: |
| x = x + residual |
| + if post_residual_addition is not None: |
| + x = x + post_residual_addition |
| residual = x |
| |
| x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) |
| @@ -466,9 +474,9 @@ class GemmaRMSNorm(MultiPlatformOp): |
| self, |
| x: torch.Tensor, |
| residual: Optional[torch.Tensor] = None, |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| - return self._forward_impl(x, residual, **kwargs) |
| + return self._forward_impl(x, residual, post_residual_addition) |
| |
| |
| class Gemma3RMSNorm(MultiPlatformOp): |
| @@ -481,22 +489,22 @@ class Gemma3RMSNorm(MultiPlatformOp): |
| def _norm(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
| |
| - def forward_native(self, x, **kwargs): |
| + def forward_native(self, x): |
| output = self._norm(x.float()) |
| # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) |
| # See https://github.com/huggingface/transformers/pull/29402 |
| output = output * (1.0 + self.weight.float()) |
| return output.type_as(x) |
| |
| - def forward_cpu(self, x, **kwargs): |
| + def forward_cpu(self, x): |
| if _is_cpu_amx_available and x.stride(-1) == 1: |
| return torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, self.weight, self.eps) |
| - return self.forward_native(x, **kwargs) |
| + return self.forward_native(x) |
| |
| - def forward_cuda(self, x, **kwargs): |
| - return self.forward_native(x, **kwargs) |
| + def forward_cuda(self, x): |
| + return self.forward_native(x) |
| |
| - def forward_npu(self, x, **kwargs): |
| + def forward_npu(self, x): |
| output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps) |
| return output |
| |
| |
| |
| |
| |
| @@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module): |
| None, # bias |
| True, # is_vnni |
| ) |
| - elif get_global_server_args().rl_on_policy_target is not None: |
| - # Due to tie-weight, we may not be able to change lm_head's weight dtype |
| - logits = torch.matmul( |
| - hidden_states.bfloat16(), lm_head.weight.T.bfloat16() |
| - ) |
| else: |
| logits = torch.matmul( |
| hidden_states.to(lm_head.weight.dtype), lm_head.weight.T |
| |
| |
| |
| |
| @@ -14,6 +14,7 @@ import torch.nn.functional as F |
| import triton.language as tl |
| |
| from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig |
| +from sglang.srt.server_args import get_global_server_args |
| from sglang.srt.utils import ( |
| cpu_has_amx_support, |
| get_bool_env_var, |
| @@ -573,7 +574,10 @@ def fused_experts_impl( |
| ).squeeze(dim=1) |
| else: |
| # According to micro benchmark results, torch.compile can get better performance for small token. |
| - if tokens_in_chunk <= 32: |
| + if ( |
| + not get_global_server_args().enable_deterministic_inference |
| + and tokens_in_chunk <= 32 |
| + ): |
| moe_sum_reduce_torch_compile( |
| intermediate_cache3.view(*intermediate_cache3.shape), |
| out_hidden_states[begin_chunk_idx:end_chunk_idx], |
| |
| |
| |
| |
| @@ -647,7 +647,7 @@ class FusedMoE(torch.nn.Module): |
| "CompressedTensorsWNA16MarlinMoEMethod", |
| "CompressedTensorsWNA16MoEMethod", |
| ] |
| - ) |
| + ) and "zero" not in weight_name |
| else loaded_weight |
| ) |
| |
| |
| |
| |
| |
| @@ -1,5 +1,6 @@ |
| import logging |
| from abc import ABC |
| +from contextlib import contextmanager |
| from typing import Optional |
| |
| import numpy as np |
| @@ -8,13 +9,18 @@ import torch |
| |
| from sglang.srt.configs.model_config import ModelConfig |
| from sglang.srt.layers.dp_attention import ( |
| + attn_tp_all_gather_into_tensor, |
| get_attention_dp_rank, |
| + get_attention_tp_size, |
| get_dp_local_info, |
| is_dp_attention_enabled, |
| ) |
| from sglang.srt.mem_cache.memory_pool import ReqToTokenPool |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch |
| from sglang.srt.server_args import get_global_server_args |
| +from sglang.srt.layers.moe import ( |
| + get_moe_a2a_backend, |
| +) |
| |
| logger = logging.getLogger(__name__) |
| |
| @@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): |
| device=device, |
| ) |
| |
| + if get_moe_a2a_backend().is_deepep(): |
| + attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 |
| + self.gather_buffer = torch.empty( |
| + ( |
| + self.device_cache.buffer.shape[0] * attn_tp_size, |
| + self.device_cache.buffer.shape[2], |
| + ), |
| + dtype=torch.int32, |
| + device=device, |
| + ) |
| + |
| def _sync_fwd_experts_buffer_DtoH( |
| self, |
| forward_batch: ForwardBatch, |
| can_run_graph: bool, |
| cuda_graph_batch: int, |
| ): |
| - if is_dp_attention_enabled(): |
| + # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer |
| + # contains data from all DP ranks. We should not slice by DP rank in this case. |
| + if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): |
| local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) |
| # handle with cuda graph padding |
| if can_run_graph: |
| @@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): |
| ].cpu() |
| |
| def capture(self, layer_id: int, topk_ids: torch.Tensor): |
| + if get_moe_a2a_backend().is_deepep(): |
| + local_topk_ids = topk_ids |
| + topk_ids = self.gather_buffer[ |
| + : local_topk_ids.size(0) * get_attention_tp_size() |
| + ] |
| + attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) |
| self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) |
| |
| def get_routed_experts( |
| |
| |
| |
| |
| @@ -442,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig): |
| ) |
| is_static = not weight_quant.dynamic |
| |
| - return is_channel_group and input_quant_none and is_symmetric and is_static |
| + return is_channel_group and input_quant_none and is_static |
| |
| def _get_scheme_from_parts( |
| self, weight_quant: BaseModel, input_quant: BaseModel |
| |
| |
| |
| |
| @@ -30,7 +30,10 @@ from sglang.srt.layers.quantization.fp8_utils import ( |
| normalize_e4m3fn_to_e4m3fnuz, |
| ) |
| from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack |
| -from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales |
| +from sglang.srt.layers.quantization.marlin_utils import ( |
| + marlin_moe_permute_scales, |
| + moe_awq_to_marlin_zero_points |
| +) |
| from sglang.srt.layers.quantization.utils import ( |
| all_close_1d, |
| per_tensor_dequantize, |
| @@ -865,7 +868,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| self.strategy = config.strategy |
| self.group_size = config.group_size |
| self.actorder = config.actorder |
| - assert config.symmetric, "Only symmetric quantization is supported for MoE" |
| + self.sym = config.symmetric |
| |
| if not ( |
| self.quant_config.quant_format == CompressionFormat.pack_quantized.value |
| @@ -920,7 +923,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| |
| # In the case where we have actorder/g_idx, |
| # we do not partition the w2 scales |
| - load_full_w2 = self.actorder and self.group_size != -1 |
| + load_full_w2 = (self.actorder != 'static') and self.group_size != -1 |
| |
| if load_full_w2: |
| w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size |
| @@ -968,6 +971,32 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.register_parameter("w13_weight_shape", w13_weight_shape) |
| set_weight_attrs(w13_weight_shape, extra_weight_attrs) |
| |
| + # add zero param |
| + if not self.sym: |
| + w13_qzeros = torch.nn.Parameter( |
| + torch.empty( |
| + num_experts, |
| + num_groups_w13, |
| + 2 * intermediate_size_per_partition // self.packed_factor, |
| + dtype=torch.int32, |
| + ), |
| + requires_grad=False, |
| + ) |
| + layer.register_parameter("w13_weight_zero_point", w13_qzeros) |
| + set_weight_attrs(w13_qzeros, extra_weight_attrs) |
| + |
| + w2_qzeros = torch.nn.Parameter( |
| + torch.empty( |
| + num_experts, |
| + num_groups_w2, |
| + hidden_size // self.packed_factor, |
| + dtype=torch.int32, |
| + ), |
| + requires_grad=False, |
| + ) |
| + layer.register_parameter("w2_weight_zero_point", w2_qzeros) |
| + set_weight_attrs(w2_qzeros, extra_weight_attrs) |
| + |
| w13_g_idx = torch.nn.Parameter( |
| torch.empty( |
| num_experts, |
| @@ -1016,13 +1045,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.a2_scale = None |
| layer.marlin_state = GPTQMarlinState.REPACK |
| |
| + if not hasattr(layer, "_original_shapes"): |
| + layer._original_shapes = {} |
| + |
| + # Force record: these are the target GPTQ shapes for rollback. |
| + layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) |
| + layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) |
| + if not self.sym: |
| + layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape |
| + |
| + layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) |
| + layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) |
| + if not self.sym: |
| + layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) |
| + |
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: |
| + # Skip if the layer is already converted to Marlin format to prevent double-packing. |
| + if getattr(layer, "is_marlin_converted", False): |
| + return |
| + |
| + if not hasattr(layer, "_original_shapes"): |
| + layer._original_shapes = {} |
| |
| def replace_tensor(name, new_t): |
| + target_attr = getattr(layer, name) |
| + |
| + # Only save if the key doesn't exist to prevent overwriting with Marlin shapes. |
| + if name not in layer._original_shapes: |
| + # This is a safety check; `create_weights` usually handles this already. |
| + layer._original_shapes[name] = tuple(target_attr.shape) |
| + |
| # It is important to use resize_() here since it ensures |
| # the same buffer is reused |
| - getattr(layer, name).resize_(new_t.shape) |
| - getattr(layer, name).copy_(new_t) |
| + target_attr.resize_(new_t.shape) |
| + target_attr.copy_(new_t) |
| del new_t |
| |
| num_experts = layer.w13_weight_g_idx.shape[0] |
| @@ -1078,7 +1134,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.w13_weight_packed.shape[2], |
| self.num_bits, |
| ) |
| - replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) |
| + replace_tensor("w13_weight_packed", marlin_w13_qweight) |
| marlin_w2_qweight = gptq_marlin_moe_repack( |
| layer.w2_weight_packed, |
| layer.w2_g_idx_sort_indices, |
| @@ -1086,7 +1142,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.w2_weight_packed.shape[2], |
| self.num_bits, |
| ) |
| - replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) |
| + replace_tensor("w2_weight_packed", marlin_w2_qweight) |
| # Repack scales |
| marlin_w13_scales = marlin_moe_permute_scales( |
| layer.w13_weight_scale, |
| @@ -1094,7 +1150,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.w13_weight_scale.shape[2], |
| self.group_size, |
| ) |
| - replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) |
| + replace_tensor("w13_weight_scale", marlin_w13_scales) |
| |
| marlin_w2_scales = marlin_moe_permute_scales( |
| layer.w2_weight_scale, |
| @@ -1103,7 +1159,40 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| layer.w2_weight_scale.shape[2], |
| self.group_size, |
| ) |
| - replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) |
| + replace_tensor("w2_weight_scale", marlin_w2_scales) |
| + |
| + # Repack zero |
| + if not self.sym: |
| + marlin_w13_zp = moe_awq_to_marlin_zero_points( |
| + layer.w13_weight_zero_point, |
| + size_k=layer.w13_weight_zero_point.shape[1], |
| + size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor, |
| + num_bits=self.num_bits, |
| + ) |
| + replace_tensor("w13_weight_zero_point", marlin_w13_zp) |
| + |
| + marlin_w2_zp = moe_awq_to_marlin_zero_points( |
| + layer.w2_weight_zero_point, |
| + size_k=layer.w2_weight_zero_point.shape[1], |
| + size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor, |
| + num_bits=self.num_bits, |
| + ) |
| + replace_tensor("w2_weight_zero_point", marlin_w2_zp) |
| + |
| + layer.is_marlin_converted = True |
| + |
| + def restore_weights_before_loading(self, layer: torch.nn.Module): |
| + """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" |
| + if not hasattr(layer, "_original_shapes"): |
| + return |
| + |
| + for name, orig_shape in layer._original_shapes.items(): |
| + param = getattr(layer, name, None) |
| + |
| + if param is not None and param.shape != orig_shape: |
| + param.resize_(orig_shape) |
| + |
| + layer.is_marlin_converted = False |
| |
| def create_moe_runner( |
| self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig |
| @@ -1154,6 +1243,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): |
| g_idx2=layer.w2_weight_g_idx, |
| sort_indices1=layer.w13_g_idx_sort_indices, |
| sort_indices2=layer.w2_g_idx_sort_indices, |
| + w1_zeros=layer.w13_weight_zero_point if not self.sym else None, |
| + w2_zeros=layer.w2_weight_zero_point if not self.sym else None, |
| num_bits=self.num_bits, |
| is_k_full=self.is_k_full, |
| routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, |
| |
| |
| |
| |
| @@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp): |
| |
| if get_global_server_args().rl_on_policy_target is not None: |
| self._forward_method = self.forward_native |
| - self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( |
| - self._apply_rotary_emb_wrapped |
| - ) |
| + |
| self.position_cos, self.position_sin = None, None |
| |
| def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: |
| @@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding): |
| key: torch.Tensor, |
| fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| + assert ( |
| + fused_set_kv_buffer_arg is None |
| + ), "fused_set_kv_buffer_arg is not supported for npu implementation" |
| # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 |
| assert ( |
| fused_set_kv_buffer_arg is None |
| |
| |
| |
| |
| @@ -108,16 +108,11 @@ class Sampler(nn.Module): |
| if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: |
| probs_without_temp_scaling = torch.softmax(logits, dim=-1) |
| |
| - if get_global_server_args().rl_on_policy_target is not None: |
| - logits_div_temperature = ( |
| - logits.bfloat16().div(sampling_info.temperatures).bfloat16() |
| - ) |
| - logprobs_via_logsoftmax_kernel = torch.log_softmax( |
| - logits_div_temperature, dim=-1 |
| - ) |
| - |
| # Post process logits |
| logits.div_(sampling_info.temperatures) |
| + if get_global_server_args().rl_on_policy_target is not None: |
| + logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) |
| + |
| # For ascend backend, softmax is not needed before sampling |
| if not get_global_server_args().sampling_backend == "ascend" or ( |
| return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB |
| |
| |
| |
| |
| @@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): |
| success: bool |
| message: str |
| |
| +@dataclass |
| +class PostProcessWeightsReqInput(BaseReq): |
| + # Whether to restore weights before loading new weights |
| + restore_weights_before_load: bool = False |
| + # Whether to enable quantization post-processing |
| + post_process_quantization: bool = False |
| + |
| + |
| +@dataclass |
| +class PostProcessWeightsReqOutput(BaseReq): |
| + success: bool |
| + message: str |
| + |
| |
| @dataclass |
| class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): |
| @@ -1667,6 +1680,10 @@ class GetLoadReqOutput(BaseReq): |
| num_waiting_reqs: int |
| num_tokens: int |
| ts_tic: float |
| + # Per-queue breakdown: list of {name, num_reqs, num_tokens, reqs: [{rid, seqlen, input_len, output_len}]} |
| + queue_details: Optional[List[Dict[str, Any]]] = None |
| + # Running batch info |
| + running_details: Optional[Dict[str, Any]] = None |
| |
| |
| @dataclass |
| |
| |
| |
| |
| @@ -1779,7 +1779,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): |
| selected_indices=sorted_indices, buf_multiplier=buf_multiplier |
| ) |
| ): |
| - if len(sorted_indices) == 1: |
| + # We should allow all requests to be retracted in decode disaggregation mode |
| + # because there call be prealloc prefill requests. |
| + num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1 |
| + if len(sorted_indices) == num_minimum_reqs: |
| # Always keep at least one request |
| break |
| |
| |
| |
| |
| |
| @@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import ( |
| OpenSessionReqInput, |
| OpenSessionReqOutput, |
| PauseGenerationReqInput, |
| + PostProcessWeightsReqInput, |
| ProfileReq, |
| ReleaseMemoryOccupationReqInput, |
| ResumeMemoryOccupationReqInput, |
| @@ -1060,6 +1061,7 @@ class Scheduler( |
| ), |
| (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), |
| (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), |
| + (PostProcessWeightsReqInput, self.post_process_weights), |
| (GetWeightsByNameReqInput, self.get_weights_by_name), |
| (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), |
| (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), |
| |
| |
| |
| |
| @@ -553,12 +553,48 @@ class SchedulerMetricsMixin: |
| num_tokens += sum(req.seqlen for queue in waiting_queues for req in queue) |
| num_waiting_reqs = sum(len(queue) for queue in waiting_queues) |
| |
| + # Collect per-queue details |
| + queue_names = ["waiting_queue"] |
| + if self.disaggregation_mode == DisaggregationMode.PREFILL: |
| + queue_names.append("bootstrap_queue") |
| + elif self.disaggregation_mode == DisaggregationMode.DECODE: |
| + queue_names.append("prealloc_queue") |
| + queue_names.append("transfer_queue") |
| + queue_names.append("retracted_queue") |
| + |
| + queue_details = [] |
| + for name, queue in zip(queue_names, waiting_queues): |
| + reqs_info = [] |
| + for req in queue: |
| + reqs_info.append({ |
| + "seqlen": req.seqlen, |
| + }) |
| + queue_details.append({ |
| + "name": name, |
| + "num_reqs": len(queue), |
| + "num_tokens": sum(r["seqlen"] for r in reqs_info), |
| + "reqs": reqs_info, |
| + }) |
| + |
| + # Collect running batch details |
| + running_reqs_info = [] |
| + for req in self.running_batch.reqs: |
| + running_reqs_info.append({ |
| + "seqlen": req.seqlen, |
| + }) |
| + running_details = { |
| + "num_reqs": len(self.running_batch.reqs), |
| + "reqs": running_reqs_info, |
| + } |
| + |
| return GetLoadReqOutput( |
| dp_rank=self.dp_rank, |
| num_reqs=len(self.running_batch.reqs) + num_waiting_reqs, |
| num_waiting_reqs=num_waiting_reqs, |
| num_tokens=num_tokens, |
| ts_tic=time.perf_counter(), |
| + queue_details=queue_details, |
| + running_details=running_details, |
| ) |
| |
| @contextmanager |
| |
| |
| |
| |
| @@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode |
| from sglang.srt.environ import envs |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput |
| from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer |
| + |
| from sglang.srt.managers.io_struct import ( |
| AbortReq, |
| BatchEmbeddingOutput, |
| @@ -1070,7 +1071,7 @@ class SchedulerOutputProcessorMixin: |
| req.log_time_stats() |
| |
| # Send to detokenizer |
| - if reqs or is_idle_batch: |
| + if rids or is_idle_batch: |
| if self.model_config.is_multimodal_gen: |
| return |
| |
| |
| |
| |
| |
| @@ -1,6 +1,7 @@ |
| from __future__ import annotations |
| |
| import logging |
| +import os |
| import traceback |
| from typing import TYPE_CHECKING, Tuple |
| |
| @@ -12,6 +13,9 @@ from sglang.srt.constants import ( |
| GPU_MEMORY_TYPE_KV_CACHE, |
| GPU_MEMORY_TYPE_WEIGHTS, |
| ) |
| +from sglang.srt.disaggregation.utils import DisaggregationMode |
| +from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group |
| +from sglang.srt.layers.dp_attention import get_attention_tp_group |
| from sglang.srt.managers.io_struct import ( |
| CheckWeightsReqInput, |
| CheckWeightsReqOutput, |
| @@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import ( |
| GetWeightsByNameReqOutput, |
| InitWeightsUpdateGroupReqInput, |
| InitWeightsUpdateGroupReqOutput, |
| + PostProcessWeightsReqInput, |
| + PostProcessWeightsReqOutput, |
| ReleaseMemoryOccupationReqInput, |
| ReleaseMemoryOccupationReqOutput, |
| ResumeMemoryOccupationReqInput, |
| @@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin: |
| torch.distributed.barrier(group=self.tp_cpu_group) |
| return UpdateWeightsFromIPCReqOutput(success, message) |
| |
| + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): |
| + """Optional post-processing for updated weights (e.g., Marlin conversion).""" |
| + success, message = self.tp_worker.post_process_weights(recv_req) |
| + return PostProcessWeightsReqOutput(success, message) |
| + |
| def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): |
| parameter = self.tp_worker.get_weights_by_name(recv_req) |
| return GetWeightsByNameReqOutput(parameter) |
| @@ -137,6 +148,15 @@ class SchedulerUpdateWeightsMixin: |
| self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) |
| self.flush_cache() |
| |
| + if self.disaggregation_mode == DisaggregationMode.DECODE: |
| + if hasattr(self, "disagg_decode_transfer_queue"): |
| + self.disagg_decode_transfer_queue.release_memory_occupation() |
| + if hasattr(self, "disagg_decode_prealloc_queue"): |
| + self.disagg_decode_prealloc_queue.release_memory_occupation() |
| + elif self.disaggregation_mode == DisaggregationMode.PREFILL: |
| + if hasattr(self, "disagg_prefill_bootstrap_queue"): |
| + self.disagg_prefill_bootstrap_queue.release_memory_occupation() |
| + |
| if GPU_MEMORY_TYPE_WEIGHTS in tags: |
| self.stashed_model_static_state = _export_static_state( |
| self.tp_worker.model_runner.model |
| @@ -177,6 +197,15 @@ class SchedulerUpdateWeightsMixin: |
| if GPU_MEMORY_TYPE_KV_CACHE in tags: |
| self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) |
| |
| + if self.disaggregation_mode == DisaggregationMode.DECODE: |
| + if hasattr(self, "disagg_decode_transfer_queue"): |
| + self.disagg_decode_transfer_queue.resume_memory_occupation() |
| + if hasattr(self, "disagg_decode_prealloc_queue"): |
| + self.disagg_decode_prealloc_queue.resume_memory_occupation() |
| + elif self.disaggregation_mode == DisaggregationMode.PREFILL: |
| + if hasattr(self, "disagg_prefill_bootstrap_queue"): |
| + self.disagg_prefill_bootstrap_queue.resume_memory_occupation() |
| + |
| return ResumeMemoryOccupationReqOutput() |
| |
| def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): |
| |
| |
| |
| |
| @@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import ( |
| LoadLoRAAdapterReqOutput, |
| LoRAUpdateOutput, |
| OpenSessionReqInput, |
| + PostProcessWeightsReqInput, |
| + PostProcessWeightsReqOutput, |
| ProfileReq, |
| ProfileReqOutput, |
| ProfileReqType, |
| @@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin: |
| self.update_weights_from_ipc_communicator = _Communicator( |
| self.send_to_scheduler, server_args.dp_size |
| ) |
| + self.post_process_weights_communicator = _Communicator( |
| + self.send_to_scheduler, server_args.dp_size |
| + ) |
| self.get_weights_by_name_communicator = _Communicator( |
| self.send_to_scheduler, server_args.dp_size |
| ) |
| @@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin: |
| UpdateWeightsFromIPCReqOutput, |
| self.update_weights_from_ipc_communicator.handle_recv, |
| ), |
| + ( |
| + PostProcessWeightsReqOutput, |
| + self.post_process_weights_communicator.handle_recv, |
| + ), |
| ( |
| GetWeightsByNameReqOutput, |
| self.get_weights_by_name_communicator.handle_recv, |
| @@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin: |
| |
| return success, message |
| |
| + async def post_process_weights( |
| + self: TokenizerManager, |
| + obj: PostProcessWeightsReqInput, |
| + request: Optional[fastapi.Request] = None, |
| + ) -> Tuple[bool, str]: |
| + """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" |
| + self.auto_create_handle_loop() |
| + async with self.model_update_lock.writer_lock: |
| + results = await self.post_process_weights_communicator(obj) |
| + return _Communicator.merge_results(results) |
| + |
| async def init_weights_send_group_for_remote_instance( |
| self, |
| obj: InitWeightsSendGroupForRemoteInstanceReqInput, |
| |
| |
| |
| |
| @@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import ( |
| InitWeightsSendGroupForRemoteInstanceReqInput, |
| InitWeightsUpdateGroupReqInput, |
| LoadLoRAAdapterReqInput, |
| + PostProcessWeightsReqInput, |
| SendWeightsToRemoteInstanceReqInput, |
| UnloadLoRAAdapterReqInput, |
| UpdateWeightFromDiskReqInput, |
| @@ -175,6 +176,11 @@ class BaseTpWorker(ABC): |
| success, message = self.model_runner.update_weights_from_ipc(recv_req) |
| return success, message |
| |
| + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): |
| + """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" |
| + success, message = self.model_runner.post_process_weights(recv_req) |
| + return success, message |
| + |
| def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): |
| parameter = self.model_runner.get_weights_by_name( |
| recv_req.name, recv_req.truncate_size |
| |
| |
| |
| |
| @@ -287,6 +287,85 @@ def alloc_decode_kernel( |
| tl.store(out_indices + pid, page * page_size) |
| |
| |
| +def alloc_extend_torch_fallback( |
| + prefix_lens_cpu: torch.Tensor, |
| + seq_lens_cpu: torch.Tensor, |
| + last_loc: torch.Tensor, |
| + free_pages: torch.Tensor, |
| + out_indices: torch.Tensor, |
| + page_size: int, |
| + debug_mode: bool = False, |
| +): |
| + extend_lens_cpu = (seq_lens_cpu - prefix_lens_cpu).to(torch.int64) |
| + if extend_lens_cpu.numel() == 0: |
| + return |
| + |
| + output_start_locs_cpu = torch.cumsum(extend_lens_cpu, dim=0) - extend_lens_cpu |
| + num_pages_after = (seq_lens_cpu + page_size - 1) // page_size |
| + num_pages_before = (prefix_lens_cpu + page_size - 1) // page_size |
| + num_new_pages_cpu = num_pages_after - num_pages_before |
| + page_start_locs_cpu = torch.cumsum(num_new_pages_cpu, dim=0) - num_new_pages_cpu |
| + |
| + total_new_pages = int(num_new_pages_cpu.sum().item()) |
| + if total_new_pages > free_pages.numel(): |
| + return |
| + |
| + if debug_mode: |
| + assert int(extend_lens_cpu.sum().item()) == out_indices.numel() |
| + |
| + prefix_lens_list = prefix_lens_cpu.tolist() |
| + seq_lens_list = seq_lens_cpu.tolist() |
| + extend_lens_list = extend_lens_cpu.tolist() |
| + out_start_list = output_start_locs_cpu.tolist() |
| + page_start_list = page_start_locs_cpu.tolist() |
| + num_new_pages_list = num_new_pages_cpu.tolist() |
| + |
| + device = out_indices.device |
| + dtype = out_indices.dtype |
| + offsets_page = torch.arange(page_size, device=device, dtype=dtype) |
| + |
| + for i, extend_len in enumerate(extend_lens_list): |
| + if extend_len == 0: |
| + continue |
| + |
| + pre_len = prefix_lens_list[i] |
| + seq_len = seq_lens_list[i] |
| + out_start = out_start_list[i] |
| + page_start = page_start_list[i] |
| + num_new_pages = num_new_pages_list[i] |
| + |
| + pre_mod = pre_len % page_size |
| + part1 = min(extend_len, page_size - pre_mod) if pre_mod != 0 else 0 |
| + if part1: |
| + start_val = last_loc[i] + 1 |
| + out_indices[out_start : out_start + part1] = start_val + torch.arange( |
| + part1, device=device, dtype=dtype |
| + ) |
| + if part1 == extend_len: |
| + continue |
| + |
| + ceil_pre_pages = (pre_len + page_size - 1) // page_size |
| + full_pages_after = seq_len // page_size |
| + num_full_pages = full_pages_after - ceil_pre_pages |
| + if num_full_pages < 0: |
| + num_full_pages = 0 |
| + part2 = num_full_pages * page_size |
| + if part2: |
| + pages = free_pages[page_start : page_start + num_full_pages] |
| + full_indices = (pages[:, None] * page_size + offsets_page).reshape(-1) |
| + out_indices[out_start + part1 : out_start + part1 + part2] = full_indices |
| + if part1 + part2 == extend_len: |
| + continue |
| + |
| + part3 = extend_len - part1 - part2 |
| + if part3: |
| + last_page = free_pages[page_start + num_new_pages - 1] |
| + out_indices[out_start + part1 + part2 : out_start + extend_len] = ( |
| + last_page * page_size |
| + + torch.arange(part3, device=device, dtype=dtype) |
| + ) |
| + |
| + |
| class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): |
| """ |
| An allocator managing the indices to kv cache data. |
| @@ -349,11 +428,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): |
| (last_loc + 1) % self.page_size == prefix_lens % self.page_size |
| ) |
| |
| - self.seen_max_num_extend_tokens_next_power_of_2 = max( |
| - self.seen_max_num_extend_tokens_next_power_of_2, |
| - next_power_of_2(extend_num_tokens), |
| - ) |
| - |
| bs = len(prefix_lens) |
| if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len( |
| self.free_pages |
| @@ -363,16 +437,34 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): |
| out_indices = torch.empty( |
| (extend_num_tokens,), dtype=torch.int64, device=self.device |
| ) |
| - alloc_extend_kernel[(bs,)]( |
| - prefix_lens, |
| - seq_lens, |
| - last_loc, |
| - self.free_pages, |
| - out_indices, |
| - next_power_of_2(bs), |
| - self.page_size, |
| - self.seen_max_num_extend_tokens_next_power_of_2, |
| - ) |
| + |
| + # Use PyTorch fallback for large extend_num_tokens to avoid slow Triton compilation |
| + MAX_TRITON_EXTEND_TOKENS = 65536 # 64K |
| + if next_power_of_2(extend_num_tokens) > MAX_TRITON_EXTEND_TOKENS: |
| + alloc_extend_torch_fallback( |
| + prefix_lens_cpu=prefix_lens_cpu, |
| + seq_lens_cpu=seq_lens_cpu, |
| + last_loc=last_loc, |
| + free_pages=self.free_pages, |
| + out_indices=out_indices, |
| + page_size=self.page_size, |
| + debug_mode=self.debug_mode, |
| + ) |
| + else: |
| + self.seen_max_num_extend_tokens_next_power_of_2 = max( |
| + self.seen_max_num_extend_tokens_next_power_of_2, |
| + next_power_of_2(extend_num_tokens), |
| + ) |
| + alloc_extend_kernel[(bs,)]( |
| + prefix_lens, |
| + seq_lens, |
| + last_loc, |
| + self.free_pages, |
| + out_indices, |
| + next_power_of_2(bs), |
| + self.page_size, |
| + self.seen_max_num_extend_tokens_next_power_of_2, |
| + ) |
| |
| if self.debug_mode: |
| assert len(torch.unique(out_indices)) == len(out_indices) |
| |
| |
| |
| |
| @@ -11,10 +11,15 @@ import torch |
| |
| from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation |
| from sglang.srt.mem_cache.base_prefix_cache import MatchResult |
| -from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool |
| +from sglang.srt.mem_cache.memory_pool import ( |
| + MHATokenToKVPool, |
| + MLATokenToKVPool, |
| + NSATokenToKVPool, |
| +) |
| from sglang.srt.mem_cache.memory_pool_host import ( |
| MHATokenToKVPoolHost, |
| MLATokenToKVPoolHost, |
| + NSATokenToKVPoolHost, |
| ) |
| from sglang.srt.mem_cache.radix_cache import ( |
| RadixCache, |
| @@ -54,6 +59,16 @@ class HiRadixCache(RadixCache): |
| server_args.hicache_mem_layout, |
| allocator_type=server_args.hicache_storage_backend, |
| ) |
| + elif isinstance(self.kv_cache, NSATokenToKVPool): |
| + # Check NSA before MLA since NSATokenToKVPool is a subclass of MLATokenToKVPool |
| + self.token_to_kv_pool_host = NSATokenToKVPoolHost( |
| + self.kv_cache, |
| + server_args.hicache_ratio, |
| + server_args.hicache_size, |
| + self.page_size, |
| + server_args.hicache_mem_layout, |
| + allocator_type=server_args.hicache_storage_backend, |
| + ) |
| elif isinstance(self.kv_cache, MLATokenToKVPool): |
| self.token_to_kv_pool_host = MLATokenToKVPoolHost( |
| self.kv_cache, |
| @@ -64,7 +79,7 @@ class HiRadixCache(RadixCache): |
| allocator_type=server_args.hicache_storage_backend, |
| ) |
| else: |
| - raise ValueError(f"HiRadixCache only supports MHA and MLA yet") |
| + raise ValueError(f"HiRadixCache only supports MHA and MLA and NSA yet") |
| |
| self.tp_group = params.tp_cache_group |
| self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) |
| |
| |
| |
| |
| @@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool): |
| with ( |
| torch.cuda.use_mem_pool(self.custom_mem_pool) |
| if self.custom_mem_pool |
| - else nullcontext() |
| + else nullcontext(), |
| + self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), |
| ): |
| self.index_k_with_scale_buffer = [ |
| torch.zeros( |
| @@ -1700,6 +1701,11 @@ class NSATokenToKVPool(MLATokenToKVPool): |
| ) |
| for _ in range(layer_num) |
| ] |
| + self.index_k_with_scale_buffer_ptrs = torch.tensor( |
| + [x.data_ptr() for x in self.index_k_with_scale_buffer], |
| + dtype=torch.uint64, |
| + device=self.device, |
| + ) |
| self._finalize_allocation_log(size) |
| |
| def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: |
| @@ -1775,6 +1781,50 @@ class NSATokenToKVPool(MLATokenToKVPool): |
| ] |
| return data_ptrs, data_lens, item_lens |
| |
| + def get_cpu_copy(self, indices): |
| + # First, save the kv_buffer (inherited from MLATokenToKVPool) |
| + kv_cache_cpu = super().get_cpu_copy(indices) |
| + |
| + # Additionally, save the index_k_with_scale_buffer (page-indexed) |
| + page_indices = indices[:: self.page_size] // self.page_size |
| + torch.cuda.synchronize() |
| + index_k_cpu = [] |
| + chunk_size = self.cpu_offloading_chunk_size |
| + # Convert chunk_size from token-level to page-level |
| + page_chunk_size = max(1, chunk_size // self.page_size) |
| + for layer_id in range(self.layer_num): |
| + index_k_cpu.append([]) |
| + for i in range(0, len(page_indices), page_chunk_size): |
| + chunk_page_indices = page_indices[i : i + page_chunk_size] |
| + idx_cpu = self.index_k_with_scale_buffer[layer_id][ |
| + chunk_page_indices |
| + ].to("cpu", non_blocking=True) |
| + index_k_cpu[-1].append(idx_cpu) |
| + torch.cuda.synchronize() |
| + |
| + return {"kv": kv_cache_cpu, "index_k": index_k_cpu} |
| + |
| + def load_cpu_copy(self, kv_cache_cpu_dict, indices): |
| + # Restore the kv_buffer (inherited from MLATokenToKVPool) |
| + super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices) |
| + |
| + # Restore the index_k_with_scale_buffer (page-indexed) |
| + page_indices = indices[:: self.page_size] // self.page_size |
| + index_k_cpu = kv_cache_cpu_dict["index_k"] |
| + torch.cuda.synchronize() |
| + chunk_size = self.cpu_offloading_chunk_size |
| + page_chunk_size = max(1, chunk_size // self.page_size) |
| + for layer_id in range(self.layer_num): |
| + for i in range(0, len(page_indices), page_chunk_size): |
| + chunk_page_indices = page_indices[i : i + page_chunk_size] |
| + idx_cpu = index_k_cpu[layer_id][i // page_chunk_size] |
| + assert idx_cpu.shape[0] == len(chunk_page_indices) |
| + idx_chunk = idx_cpu.to( |
| + self.index_k_with_scale_buffer[0].device, non_blocking=True |
| + ) |
| + self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk |
| + torch.cuda.synchronize() |
| + |
| def get_kv_size_bytes(self): |
| kv_size_bytes = super().get_kv_size_bytes() |
| for index_k_cache in self.index_k_with_scale_buffer: |
| |
| |
| |
| |
| @@ -15,7 +15,12 @@ from sglang.jit_kernel.hicache import ( |
| from sglang.jit_kernel.hicache import ( |
| transfer_hicache_one_layer as jit_transfer_hicache_one_layer, |
| ) |
| -from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool |
| +from sglang.srt.mem_cache.memory_pool import ( |
| + KVCache, |
| + MHATokenToKVPool, |
| + MLATokenToKVPool, |
| + NSATokenToKVPool, |
| +) |
| from sglang.srt.utils import is_cuda, is_npu, is_xpu |
| |
| _is_cuda = is_cuda() |
| @@ -1015,3 +1020,199 @@ class MLATokenToKVPoolHost(HostKVCache): |
| else: |
| raise ValueError(f"Unsupported layout: {self.layout}") |
| return ptr_list, element_size_list |
| + |
| + |
| +class NSATokenToKVPoolHost(MLATokenToKVPoolHost): |
| + """ |
| + Host memory pool for NSA (Native Sparse Attention) KV cache. |
| + |
| + NSA extends MLA with an additional index_k_with_scale_buffer that stores |
| + sparse attention indexing information. This class ensures that buffer is |
| + also backed up and restored during hicache operations. |
| + """ |
| + |
| + device_pool: NSATokenToKVPool |
| + |
| + def __init__( |
| + self, |
| + device_pool: NSATokenToKVPool, |
| + host_to_device_ratio: float, |
| + host_size: int, |
| + page_size: int, |
| + layout: str, |
| + pin_memory: bool = True, |
| + device: str = "cpu", |
| + allocator_type: str = "default", |
| + ): |
| + # Store NSA-specific attributes before calling parent __init__ |
| + self.index_head_dim = device_pool.index_head_dim |
| + self.quant_block_size = device_pool.quant_block_size |
| + self.index_k_with_scale_buffer_dtype = device_pool.index_k_with_scale_buffer_dtype |
| + |
| + super().__init__( |
| + device_pool, |
| + host_to_device_ratio, |
| + host_size, |
| + page_size, |
| + layout, |
| + pin_memory, |
| + device, |
| + allocator_type, |
| + ) |
| + |
| + # Initialize index buffer references and pointers for efficient transfer |
| + self.index_data_refs = [ |
| + self.index_k_with_scale_buffer[i] for i in range(self.layer_num) |
| + ] |
| + self.index_data_ptrs = torch.tensor( |
| + [x.data_ptr() for x in self.index_data_refs], |
| + dtype=torch.uint64, |
| + device=self.device_pool.device, |
| + ) |
| + |
| + def get_size_per_token(self): |
| + # Get base MLA size |
| + base_size = super().get_size_per_token() |
| + |
| + # Add NSA index buffer size per token |
| + # index_k_with_scale_buffer shape per layer: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4)) |
| + # Per token: (index_head_dim + index_head_dim // quant_block_size * 4) * dtype.itemsize * layer_num |
| + index_size_per_token = ( |
| + (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) |
| + * self.index_k_with_scale_buffer_dtype.itemsize |
| + * self.layer_num |
| + ) |
| + |
| + return base_size + index_size_per_token |
| + |
| + def init_kv_buffer(self): |
| + # Initialize base MLA kv_buffer |
| + buffer = super().init_kv_buffer() |
| + |
| + # Initialize NSA index_k_with_scale_buffer on host |
| + # Layout matches device pool: (num_pages, page_size * (index_head_dim + index_head_dim // quant_block_size * 4)) |
| + index_buffer_second_dim = self.page_size * ( |
| + self.index_head_dim + self.index_head_dim // self.quant_block_size * 4 |
| + ) |
| + self.index_stride_size = (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) * self.index_k_with_scale_buffer_dtype.itemsize |
| + |
| + alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device] |
| + self.index_k_with_scale_buffer = [ |
| + alloc_func( |
| + (self.page_num, index_buffer_second_dim), |
| + dtype=self.index_k_with_scale_buffer_dtype, |
| + device=self.device, |
| + pin_memory=self.pin_memory, |
| + allocator=self.allocator, |
| + ) |
| + for _ in range(self.layer_num) |
| + ] |
| + |
| + return buffer |
| + |
| + def _load_indexer_to_device_per_layer( |
| + self, device_pool, host_indices, device_indices, layer_id, io_backend |
| + ): |
| + """Load index_k_with_scale_buffer from host to device for a specific layer.""" |
| + # Convert token indices to page indices |
| + # host_indices and device_indices are token-level indices |
| + # index_k_with_scale_buffer is page-level with shape (num_pages, page_size * dim) |
| + |
| + if io_backend == "kernel": |
| + # Use page-level copy for index buffer |
| + # Calculate page indices from token indices |
| + page_indices_host = host_indices[:: self.page_size] // self.page_size |
| + page_indices_device = device_indices[:: self.page_size] // self.page_size |
| + |
| + src_buffer = self.index_k_with_scale_buffer[layer_id] |
| + dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer] |
| + |
| + # Copy each page |
| + # for i in range(len(page_indices_host)): |
| + # src_page_idx = page_indices_host[i].item() |
| + # dst_page_idx = page_indices_device[i].item() |
| + # dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) |
| + if self.layout == "layer_first": |
| + transfer_kv_per_layer_mla( |
| + src=src_buffer, |
| + dst=dst_buffer, |
| + src_indices=page_indices_host, |
| + dst_indices=page_indices_device, |
| + item_size=self.index_stride_size * self.page_size, |
| + ) |
| + else: |
| + raise ValueError(f"Unsupported layout: {self.layout}") |
| + |
| + elif io_backend == "direct": |
| + # Direct I/O copy for index buffer |
| + page_indices_host = host_indices[:: self.page_size] // self.page_size |
| + page_indices_device = device_indices[:: self.page_size] // self.page_size |
| + |
| + src_buffer = self.index_k_with_scale_buffer[layer_id] |
| + dst_buffer = device_pool.index_k_with_scale_buffer[layer_id - device_pool.start_layer] |
| + |
| + for i in range(len(page_indices_host)): |
| + src_page_idx = page_indices_host[i].item() |
| + dst_page_idx = page_indices_device[i].item() |
| + dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) |
| + else: |
| + raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}") |
| + |
| + def _backup_indexer_from_device_all_layer( |
| + self, device_pool, host_indices, device_indices, io_backend |
| + ): |
| + """Backup index_k_with_scale_buffer from device to host for all layers.""" |
| + # Convert token indices to page indices |
| + page_indices_host = host_indices[:: self.page_size] // self.page_size |
| + page_indices_device = device_indices[:: self.page_size] // self.page_size |
| + |
| + # if io_backend in ["kernel", "direct"]: |
| + if io_backend == "kernel": |
| + if self.layout == "layer_first": |
| + transfer_kv_all_layer_mla( |
| + src_layers=device_pool.index_k_with_scale_buffer_ptrs, |
| + dst_layers=self.index_data_ptrs, |
| + src_indices=page_indices_device, |
| + dst_indices=page_indices_host, |
| + item_size=self.index_stride_size * self.page_size, |
| + num_layers=self.layer_num, |
| + ) |
| + else: |
| + raise ValueError(f"Unsupported layout: {self.layout}") |
| + elif io_backend == "direct": |
| + for layer_id in range(self.layer_num): |
| + src_buffer = device_pool.index_k_with_scale_buffer[layer_id] |
| + dst_buffer = self.index_k_with_scale_buffer[layer_id] |
| + |
| + for i in range(len(page_indices_device)): |
| + src_page_idx = page_indices_device[i].item() |
| + dst_page_idx = page_indices_host[i].item() |
| + dst_buffer[dst_page_idx].copy_(src_buffer[src_page_idx], non_blocking=True) |
| + else: |
| + raise ValueError(f"Unsupported IO backend for NSA indexer: {io_backend}") |
| + |
| + def load_to_device_per_layer( |
| + self, device_pool, host_indices, device_indices, layer_id, io_backend |
| + ): |
| + """Load KV cache and index buffer from host to device for a specific layer.""" |
| + # Load base MLA kv_buffer |
| + super().load_to_device_per_layer( |
| + device_pool, host_indices, device_indices, layer_id, io_backend |
| + ) |
| + # Load NSA index_k_with_scale_buffer |
| + self._load_indexer_to_device_per_layer( |
| + device_pool, host_indices, device_indices, layer_id, io_backend |
| + ) |
| + |
| + def backup_from_device_all_layer( |
| + self, device_pool, host_indices, device_indices, io_backend |
| + ): |
| + """Backup KV cache and index buffer from device to host for all layers.""" |
| + # Backup base MLA kv_buffer |
| + super().backup_from_device_all_layer( |
| + device_pool, host_indices, device_indices, io_backend |
| + ) |
| + # Backup NSA index_k_with_scale_buffer |
| + self._backup_indexer_from_device_all_layer( |
| + device_pool, host_indices, device_indices, io_backend |
| + ) |
| |
| |
| |
| |
| @@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): |
| ) |
| |
| # Init routed experts capturer |
| - self.init_routed_experts_capturer() |
| + if not self.is_draft_worker: |
| + self.init_routed_experts_capturer() |
| |
| if self.device == "cuda": |
| self.init_cublas() |
| @@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin): |
| output.expert_distribution_metrics = recorder_outputs.get("metrics") |
| |
| # Copy cached routing experts' buffers back to CPU cache |
| - get_global_experts_capturer().on_forward_end( |
| - forward_batch=forward_batch, |
| - can_run_graph=output.can_run_graph, |
| - cuda_graph_batch=getattr(self.graph_runner, "bs", None), |
| - ) |
| + if not self.is_draft_worker: |
| + # In speculative decoding, num_tokens_per_bs > 1, so we need to pass |
| + # the actual number of tokens per dp rank in cuda graph, not batch size. |
| + cuda_graph_num_tokens = None |
| + if getattr(self.graph_runner, "bs", None): |
| + cuda_graph_num_tokens = ( |
| + self.graph_runner.bs * self.graph_runner.num_tokens_per_bs |
| + ) |
| + get_global_experts_capturer().on_forward_end( |
| + forward_batch=forward_batch, |
| + can_run_graph=output.can_run_graph, |
| + cuda_graph_batch=cuda_graph_num_tokens, |
| + ) |
| |
| if self.eplb_manager is not None: |
| self.eplb_manager.on_forward_pass_end() |
| @@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin): |
| logger.error(f"IPC weight update failed: {e}") |
| return False, str(e) |
| |
| + def post_process_weights(self, recv_req): |
| + """ |
| + Execute post-processing logic for model weights, such as Marlin quantization format conversion. |
| + """ |
| + from sglang.srt.model_loader.loader import device_loading_context |
| + |
| + target_device = torch.device("cuda", torch.cuda.current_device()) |
| + |
| + if recv_req.restore_weights_before_load: |
| + for _, module in self.model.named_modules(): |
| + quant_method = getattr(module, "quant_method", None) |
| + |
| + # Check if the module supports restoring weights |
| + if quant_method is not None and hasattr( |
| + quant_method, "restore_weights_before_loading" |
| + ): |
| + |
| + with device_loading_context(module, target_device): |
| + quant_method.restore_weights_before_loading(module) |
| + |
| + if recv_req.post_process_quantization: |
| + # Iterate through all modules to apply specific post-loading processing |
| + for _, module in self.model.named_modules(): |
| + quant_method = getattr(module, "quant_method", None) |
| + |
| + # Check if the module supports quantization post-processing |
| + if quant_method is not None and hasattr( |
| + quant_method, "process_weights_after_loading" |
| + ): |
| + |
| + # Apply the post-processing (e.g., repacking weights for Marlin kernel) |
| + with device_loading_context(module, target_device): |
| + quant_method.process_weights_after_loading(module) |
| + |
| + return True, "Success" |
| |
| def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): |
| params_dict = dict(model.named_parameters()) |
| |
| |
| |
| |
| @@ -159,6 +159,7 @@ from sglang.srt.utils import ( |
| make_layers, |
| use_intel_amx_backend, |
| ) |
| +from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend |
| |
| _is_hip = is_hip() |
| _is_cuda = is_cuda() |
| @@ -434,6 +435,8 @@ def handle_attention_nsa(attn, forward_batch): |
| backend = forward_batch.attn_backend |
| if isinstance(backend, TboAttnBackend): # if enable tbo, get primary backend |
| backend = backend.primary |
| + if isinstance(backend, HybridAttnBackend): |
| + backend = backend._select_backend(forward_batch.forward_mode) |
| if hasattr(backend, "use_mha") and backend.use_mha: |
| return AttnForwardMethod.MHA_ONE_SHOT |
| return AttnForwardMethod.MLA |
| @@ -2704,7 +2707,11 @@ class DeepseekV2AttentionMLA(nn.Module): |
| ): |
| k = k_nope.new_empty(*k_shape) |
| concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) |
| - elif _is_cuda: |
| + elif _is_cuda and all( |
| + # (i.bit_count() == 1) == (is_power_of_two(i)) |
| + i.bit_count() == 1 |
| + for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1]) |
| + ): |
| # fa3 mha support fp8 inputs |
| if ( |
| self.current_attention_backend == "fa3" |
| |
| |
| |
| |
| @@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): |
| self.act_fn = SiluAndMul() |
| |
| def forward(self, x): |
| - if get_global_server_args().rl_on_policy_target is not None: |
| - x = x.bfloat16() |
| - |
| gate_up, _ = self.gate_up_proj(x) |
| x = self.act_fn(gate_up) |
| x, _ = self.down_proj(x) |
| @@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): |
| quant_config=quant_config, |
| enable_tp=not is_dp_attention_enabled(), |
| prefix=add_prefix("embed_tokens", prefix), |
| - params_dtype=( |
| - torch.float32 |
| - if get_global_server_args().rl_on_policy_target is not None |
| - else None |
| - ), |
| ) |
| else: |
| self.embed_tokens = PPMissingLayer() |
| @@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): |
| if self.pp_group.is_last_rank: |
| norm_kwargs = ( |
| dict( |
| - weight_dtype=torch.float32, |
| cast_x_before_out_mul=True, |
| - override_orig_dtype=torch.float32, |
| - fp32_residual=True, |
| + fp32_residual=False, |
| ) |
| if get_global_server_args().rl_on_policy_target is not None |
| else {} |
| |
| |
| |
| |
| @@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module): |
| prefix=add_prefix("layers", prefix), |
| ) |
| if self.pp_group.is_last_rank: |
| - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| + norm_kwargs = ( |
| + dict( |
| + cast_x_before_out_mul=True, |
| + fp32_residual=False, |
| + ) |
| + if get_global_server_args().rl_on_policy_target is not None |
| + else {} |
| + ) |
| + self.norm = RMSNorm( |
| + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs |
| + ) |
| else: |
| self.norm = PPMissingLayer(return_tuple=True) |
| |
| |
| |
| |
| |
| @@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): |
| |
| norm_kwargs = ( |
| dict( |
| - weight_dtype=torch.float32, |
| cast_x_before_out_mul=True, |
| + fp32_residual=False, |
| ) |
| if get_global_server_args().rl_on_policy_target is not None |
| else {} |
| @@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module): |
| |
| norm_kwargs = ( |
| dict( |
| - weight_dtype=torch.float32, |
| cast_x_before_out_mul=True, |
| - override_orig_dtype=torch.float32, |
| - fp32_residual=True, |
| + fp32_residual=False, |
| ) |
| if get_global_server_args().rl_on_policy_target is not None |
| else {} |
| @@ -276,14 +274,14 @@ class Qwen3DecoderLayer(nn.Module): |
| hidden_states: torch.Tensor, |
| forward_batch: ForwardBatch, |
| residual: Optional[torch.Tensor], |
| - **kwargs, |
| + post_residual_addition: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| # Self Attention |
| hidden_states, residual = self.layer_communicator.prepare_attn( |
| hidden_states, |
| residual, |
| forward_batch, |
| - **kwargs, |
| + post_residual_addition=post_residual_addition, |
| ) |
| if hidden_states.shape[0] != 0: |
| hidden_states = self.self_attn( |
| |
| |
| |
| |
| @@ -22,6 +22,7 @@ import math |
| from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar |
| |
| import torch |
| +import torch.nn.functional as F |
| from torch import nn |
| from transformers import PretrainedConfig |
| |
| @@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( |
| ) |
| from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class |
| from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE |
| -from sglang.srt.layers.moe.topk import TopK |
| +from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK |
| from sglang.srt.layers.moe.utils import RoutingMethodType |
| from sglang.srt.layers.quantization.base_config import QuantizationConfig |
| from sglang.srt.layers.radix_attention import RadixAttention |
| @@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): |
| use_grouped_topk=False, |
| layer_id=layer_id, |
| ) |
| + self.top_k = config.num_experts_per_tok |
| |
| self.experts = get_moe_impl_class(quant_config)( |
| num_experts=config.num_experts |
| @@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): |
| |
| # router_logits: (num_tokens, n_experts) |
| router_logits, _ = self.gate(hidden_states) |
| - topk_output = self.topk(hidden_states, router_logits) |
| + |
| + if get_global_server_args().rl_on_policy_target is not None: |
| + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) |
| + routing_weights, selected_experts = torch.topk( |
| + routing_weights, self.top_k, dim=-1 |
| + ) |
| + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) |
| + routing_weights = routing_weights.to(hidden_states.dtype) |
| + topk_output = StandardTopKOutput( |
| + topk_weights=routing_weights, |
| + topk_ids=selected_experts, |
| + router_logits=router_logits, |
| + ) |
| + else: |
| + topk_output = self.topk(hidden_states, router_logits) |
| + |
| final_hidden_states = self.experts(hidden_states, topk_output) |
| if ( |
| self.tp_size > 1 |
| @@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): |
| ) |
| self.compatible_with_fused_kv_buffer = ( |
| False if isinstance(self.rotary_emb, MRotaryEmbedding) else True |
| - ) |
| + ) and (get_global_server_args().rl_on_policy_target is None) |
| self.compatible_with_fused_qk_norm_rope = ( |
| not isinstance(self.rotary_emb, MRotaryEmbedding) |
| ) and self.head_dim in (64, 128, 256) |
| self.use_fused_qk_norm_rope = ( |
| get_global_server_args().enable_fused_qk_norm_rope |
| and self.compatible_with_fused_qk_norm_rope |
| + and (get_global_server_args().rl_on_policy_target is None) |
| ) |
| self._used_fused_qk_norm_rope_last_call = False |
| |
| @@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): |
| prefix=add_prefix("attn", prefix), |
| ) |
| |
| - self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) |
| - self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) |
| + norm_kwargs = ( |
| + dict( |
| + cast_x_before_out_mul=True, |
| + fp32_residual=False, |
| + ) |
| + if get_global_server_args().rl_on_policy_target is not None |
| + else {} |
| + ) |
| + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) |
| + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) |
| self.alt_stream = alt_stream |
| |
| def op_prepare(self, state): |
| @@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module): |
| quant_config=quant_config, |
| prefix=add_prefix("mlp", prefix), |
| ) |
| - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| + norm_kwargs = ( |
| + dict( |
| + cast_x_before_out_mul=True, |
| + fp32_residual=False, |
| + ) |
| + if get_global_server_args().rl_on_policy_target is not None |
| + else {} |
| + ) |
| + self.input_layernorm = RMSNorm( |
| + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs |
| + ) |
| self.post_attention_layernorm = RMSNorm( |
| - config.hidden_size, eps=config.rms_norm_eps |
| + config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs |
| ) |
| |
| self.layer_communicator = LayerCommunicator( |
| |
| |
| |
| |
| @@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): |
| return cos_combined, sin_combined |
| |
| def fast_pos_embed_interpolate(self, grid_thw): |
| - patch_pos_embeds_permute = [] |
| - m_size = self.spatial_merge_size |
| + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] |
| + num_grid_per_side = int(self.num_position_embeddings**0.5) |
| + device = self.pos_embed.weight.device |
| + |
| + idx_list = [[] for _ in range(4)] |
| + weight_list = [[] for _ in range(4)] |
| + |
| + for t, h, w in zip(grid_ts, grid_hs, grid_ws): |
| + h_idxs = torch.linspace(0, num_grid_per_side - 1, h) |
| + w_idxs = torch.linspace(0, num_grid_per_side - 1, w) |
| + |
| + h_idxs_floor = h_idxs.int() |
| + w_idxs_floor = w_idxs.int() |
| + h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) |
| + w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) |
| + |
| + dh = h_idxs - h_idxs_floor |
| + dw = w_idxs - w_idxs_floor |
| + |
| + base_h = h_idxs_floor * num_grid_per_side |
| + base_h_ceil = h_idxs_ceil * num_grid_per_side |
| + |
| + indices = [ |
| + (base_h[None].T + w_idxs_floor[None]).flatten(), |
| + (base_h[None].T + w_idxs_ceil[None]).flatten(), |
| + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), |
| + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), |
| + ] |
| + |
| + weights = [ |
| + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), |
| + ((1 - dh)[None].T * dw[None]).flatten(), |
| + (dh[None].T * (1 - dw)[None]).flatten(), |
| + (dh[None].T * dw[None]).flatten(), |
| + ] |
| |
| - embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device) |
| - embeds = ( |
| - self.pos_embed(embeds) |
| - .permute(1, 0) |
| - .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side) |
| + for i in range(4): |
| + idx_list[i].extend(indices[i].tolist()) |
| + weight_list[i].extend(weights[i].tolist()) |
| + |
| + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) |
| + weight_tensor = torch.tensor( |
| + weight_list, dtype=self.pos_embed.weight.dtype, device=device |
| ) |
| - for t, h, w in grid_thw: |
| - pos_embed = torch.nn.functional.interpolate( |
| - embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners |
| - ) |
| - pos_embed = pos_embed.reshape( |
| - -1, |
| - h // self.spatial_merge_size, |
| - self.spatial_merge_size, |
| - w // self.spatial_merge_size, |
| - self.spatial_merge_size, |
| + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] |
| + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] |
| + |
| + patch_pos_embeds = patch_pos_embeds.split( |
| + [h * w for h, w in zip(grid_hs, grid_ws)] |
| + ) |
| + |
| + patch_pos_embeds_permute = [] |
| + merge_size = self.spatial_merge_size |
| + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): |
| + pos_embed = pos_embed.repeat(t, 1) |
| + pos_embed = ( |
| + pos_embed.view( |
| + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 |
| + ) |
| + .permute(0, 1, 3, 2, 4, 5) |
| + .flatten(0, 4) |
| ) |
| - pos_embed = pos_embed.permute(1, 3, 2, 4, 0) |
| - pos_embed = pos_embed.flatten(0, 3).repeat(t, 1) |
| patch_pos_embeds_permute.append(pos_embed) |
| return torch.cat(patch_pos_embeds_permute) |
| |
| @@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model): |
| hidden_states + residual if residual is not None else hidden_states |
| ) |
| |
| + deepstack_embeds = None |
| + if input_deepstack_embeds is not None: |
| + prev_layer_idx = layer_idx - 1 |
| + if prev_layer_idx in self.deepstack_embed_to_decoder_layer: |
| + sep = self.hidden_size * prev_layer_idx |
| + deepstack_embeds = input_deepstack_embeds[ |
| + :, sep : sep + self.hidden_size |
| + ] |
| + |
| # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. |
| # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 |
| # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack |
| # The order matters because addition with different tensors is not associative in practice. |
| - # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. |
| - deepstack_embeds = self.get_deepstack_embeds( |
| - layer_idx - 1, input_deepstack_embeds |
| - ) |
| hidden_states, residual = layer( |
| positions, |
| hidden_states, |
| |
| |
| |
| |
| @@ -527,6 +527,7 @@ class ServerArgs: |
| cuda_graph_max_bs: Optional[int] = None |
| cuda_graph_bs: Optional[List[int]] = None |
| disable_cuda_graph: bool = False |
| + disable_draft_cuda_graph: bool = False |
| disable_cuda_graph_padding: bool = False |
| enable_profile_cuda_graph: bool = False |
| enable_cudagraph_gc: bool = False |
| @@ -3980,6 +3981,11 @@ class ServerArgs: |
| action="store_true", |
| help="Disable cuda graph.", |
| ) |
| + parser.add_argument( |
| + "--disable-draft-cuda-graph", |
| + action="store_true", |
| + help="Disable cuda graph for draft model in speculative decoding.", |
| + ) |
| parser.add_argument( |
| "--disable-cuda-graph-padding", |
| action="store_true", |
| |
| |
| |
| |
| @@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: |
| self.seq_lens.fill_(self.seq_len_fill_value) |
| self.out_cache_loc.zero_() |
| self.positions.zero_() |
| - |
| + self.topk_p.zero_() |
| + self.topk_index.zero_() |
| + self.hidden_states.zero_() |
| + self.req_pool_indices.zero_() |
| num_tokens = bs * self.num_tokens_per_bs |
| |
| # Common inputs |
| @@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner: |
| forward_batch.out_cache_loc |
| ) |
| self.positions[:raw_num_token].copy_(forward_batch.positions) |
| - self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) |
| - self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) |
| + self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) |
| + self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1)) |
| self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) |
| self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) |
| |
| |
| |
| |
| |
| @@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): |
| self.topk_index = self.topk_index[: len(new_indices)] |
| self.hidden_states = self.hidden_states[: len(new_indices)] |
| self.verified_id = self.verified_id[: len(new_indices)] |
| + if self.accept_length is not None: |
| + self.accept_length = self.accept_length[: len(new_indices)] |
| + if self.accept_length_cpu is not None: |
| + self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] |
| else: |
| # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` |
| self.topk_p = self.topk_p[new_indices] |
| @@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): |
| self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) |
| self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) |
| self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) |
| + if self.accept_length is not None and spec_info.accept_length is not None: |
| + self.accept_length = torch.cat( |
| + [self.accept_length, spec_info.accept_length] |
| + ) |
| + self.accept_length_cpu = self.accept_length.tolist() |
| + elif self.accept_length is not None: |
| + zeros = torch.zeros( |
| + [spec_info.verified_id.shape[0]], |
| + dtype=self.accept_length.dtype, |
| + device=self.accept_length.device, |
| + ) |
| + self.accept_length = torch.cat([self.accept_length, zeros]) |
| + self.accept_length_cpu = self.accept_length.tolist() |
| + elif spec_info.accept_length is not None: |
| + zeros = torch.zeros( |
| + [self.verified_id.shape[0]], |
| + dtype=self.accept_length.dtype, |
| + device=self.accept_length.device, |
| + ) |
| + self.accept_length = torch.cat([zeros, spec_info.accept_length]) |
| + self.accept_length_cpu = self.accept_length.tolist() |
| |
| |
| @dataclass |
| |
| |
| |
| |
| @@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker): |
| self.cuda_graph_runner = None |
| self.cuda_graph_runner_for_draft_extend = None |
| |
| - if self.server_args.disable_cuda_graph: |
| + if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph: |
| return |
| |
| Device2DraftCudaGraphRunner = { |
| |
| |
| |
| |
| @@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler): |
| "sglang.srt.model_executor.model_runner.", |
| "sglang.srt.layers.", |
| "sglang.srt.utils.", |
| + # --- slime --- |
| + "slime.", |
| } |
| |
| DENY_CLASSES = { |
|
|