| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """ModelRunner runs the forward passes of the models.""" | |
| import datetime | |
| import gc | |
| import inspect | |
| import json | |
| import logging | |
| import os | |
| import socket | |
| import threading | |
| import time | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from typing import Callable, List, Optional, Tuple, Union | |
| import torch | |
| import torch.distributed as dist | |
| from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig | |
| from sglang.srt.configs.device_config import DeviceConfig | |
| from sglang.srt.configs.load_config import LoadConfig, LoadFormat | |
| from sglang.srt.configs.model_config import ( | |
| AttentionArch, | |
| ModelConfig, | |
| get_nsa_index_head_dim, | |
| is_deepseek_nsa, | |
| ) | |
| from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp | |
| from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS | |
| from sglang.srt.distributed import ( | |
| get_pp_group, | |
| get_tp_group, | |
| get_world_group, | |
| init_distributed_environment, | |
| initialize_model_parallel, | |
| set_custom_all_reduce, | |
| set_mscclpp_all_reduce, | |
| set_symm_mem_all_reduce, | |
| ) | |
| from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state | |
| from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager | |
| from sglang.srt.eplb.eplb_manager import EPLBManager | |
| from sglang.srt.eplb.expert_distribution import ( | |
| ExpertDistributionRecorder, | |
| get_global_expert_distribution_recorder, | |
| set_global_expert_distribution_recorder, | |
| ) | |
| from sglang.srt.eplb.expert_location import ( | |
| ExpertLocationMetadata, | |
| compute_initial_expert_location_metadata, | |
| get_global_expert_location_metadata, | |
| set_global_expert_location_metadata, | |
| ) | |
| from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater | |
| from sglang.srt.layers import deep_gemm_wrapper | |
| from sglang.srt.layers.attention.attention_registry import ( | |
| ATTENTION_BACKENDS, | |
| attn_backend_wrapper, | |
| ) | |
| from sglang.srt.layers.attention.tbo_backend import TboAttnBackend | |
| from sglang.srt.layers.dp_attention import ( | |
| get_attention_tp_group, | |
| get_attention_tp_size, | |
| initialize_dp_attention, | |
| ) | |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | |
| from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer | |
| from sglang.srt.layers.sampler import Sampler | |
| from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model | |
| from sglang.srt.lora.lora_manager import LoRAManager | |
| from sglang.srt.lora.lora_registry import LoRARef | |
| from sglang.srt.mem_cache.allocator import ( | |
| BaseTokenToKVPoolAllocator, | |
| PagedTokenToKVPoolAllocator, | |
| SWATokenToKVPoolAllocator, | |
| TokenToKVPoolAllocator, | |
| ) | |
| from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.memory_pool import ( | |
| AscendMLAPagedTokenToKVPool, | |
| AscendTokenToKVPool, | |
| DoubleSparseTokenToKVPool, | |
| HybridLinearKVPool, | |
| HybridReqToTokenPool, | |
| MHATokenToKVPool, | |
| MLATokenToKVPool, | |
| NSATokenToKVPool, | |
| ReqToTokenPool, | |
| SWAKVPool, | |
| ) | |
| from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner | |
| from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors | |
| from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner | |
| from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( | |
| PiecewiseCudaGraphRunner, | |
| ) | |
| from sglang.srt.model_loader import get_model | |
| from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader | |
| from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( | |
| trigger_init_weights_send_group_for_remote_instance_request, | |
| ) | |
| from sglang.srt.model_loader.utils import set_default_torch_dtype | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo | |
| from sglang.srt.server_args import ( | |
| ServerArgs, | |
| get_global_server_args, | |
| set_global_server_args_for_scheduler, | |
| ) | |
| from sglang.srt.speculative.spec_info import SpeculativeAlgorithm | |
| from sglang.srt.utils import ( | |
| MultiprocessingSerializer, | |
| cpu_has_amx_support, | |
| dynamic_import, | |
| enable_show_time_cost, | |
| get_available_gpu_memory, | |
| get_bool_env_var, | |
| get_cpu_ids_by_node, | |
| init_custom_process_group, | |
| is_fa3_default_architecture, | |
| is_flashinfer_available, | |
| is_hip, | |
| is_hopper_with_cuda_12_3, | |
| is_no_spec_infer_or_topk_one, | |
| is_npu, | |
| is_sm100_supported, | |
| log_info_on_rank0, | |
| monkey_patch_p2p_access_check, | |
| monkey_patch_vllm_gguf_config, | |
| set_cuda_arch, | |
| slow_rank_detector, | |
| xpu_has_xmx_support, | |
| ) | |
| from sglang.srt.utils.offloader import ( | |
| create_offloader_from_server_args, | |
| get_offloader, | |
| set_offloader, | |
| ) | |
| from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions | |
| from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter | |
| from sglang.srt.weight_sync.tensor_bucket import ( | |
| FlattenedTensorBucket, | |
| FlattenedTensorMetadata, | |
| ) | |
# Attention backends accepted when the model uses the MLA attention
# architecture. Kept as a mutable list so that additional backends can be
# registered at runtime through add_mla_attention_backend().
MLA_ATTENTION_BACKENDS = [
    "aiter", "flashinfer", "fa3", "fa4", "triton",
    "flashmla", "cutlass_mla", "trtllm_mla", "ascend", "nsa",
]

# Attention backends for which the chunked prefix cache is left enabled.
# Also mutable; extended through add_chunked_prefix_cache_attention_backend().
CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
    "flashinfer", "fa3", "fa4",
    "flashmla", "cutlass_mla", "trtllm_mla",
]
def add_mla_attention_backend(backend_name):
    """Register ``backend_name`` as an MLA-capable attention backend.

    The name is appended to ``MLA_ATTENTION_BACKENDS`` and an info message is
    logged. Nothing happens if the name is already registered.
    """
    already_registered = backend_name in MLA_ATTENTION_BACKENDS
    if already_registered:
        return
    MLA_ATTENTION_BACKENDS.append(backend_name)
    logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
def add_chunked_prefix_cache_attention_backend(backend_name):
    """Register ``backend_name`` as supporting the chunked prefix cache.

    The name is appended to
    ``CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS`` and an info message
    is logged. Nothing happens if the name is already registered.
    """
    supported = CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
    if backend_name in supported:
        return
    supported.append(backend_name)
    logger.info(
        f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
    )
# Platform capability flags, evaluated once when the module is imported.
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu_xmx_available = xpu_has_xmx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

# The ratio of the mamba cache pool size to max_running_requests; it is safe
# when it is larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

logger = logging.getLogger(__name__)

# NPU-only setup: torch_npu is imported lazily so that other platforms do not
# need the package installed.
if _is_npu:
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch_npu.npu.set_compile_mode(jit_compile=False)
class RankZeroFilter(logging.Filter):
    """Logging filter that suppresses INFO records on every rank except rank 0.

    Records at any level other than INFO are always let through, whatever the
    rank.
    """

    def __init__(self, is_rank_zero):
        super().__init__()
        # Whether the current process is rank 0.
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        # Non-INFO records always pass; INFO records pass only on rank 0.
        return record.levelno != logging.INFO or self.is_rank_zero
| class ModelRunner: | |
| """ModelRunner runs the forward passes of the models.""" | |
def __init__(
    self,
    model_config: ModelConfig,
    mem_fraction_static: float,
    gpu_id: int,
    tp_rank: int,
    tp_size: int,
    moe_ep_rank: int,
    moe_ep_size: int,
    pp_rank: int,
    pp_size: int,
    nccl_port: int,
    server_args: ServerArgs,
    dp_rank: Optional[int] = None,
    is_draft_worker: bool = False,
    req_to_token_pool: Optional[ReqToTokenPool] = None,
    token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
):
    """Record the arguments, set up distributed state, and load the model.

    The steps run in a fixed order: arguments are stored on ``self``,
    ``model_specific_adjustment`` may rewrite ``server_args``, the (possibly
    adjusted) ``server_args`` are published globally, torch distributed is
    initialized, and ``initialize`` then loads the model and builds the
    memory pool and attention backend. Finally the piecewise CUDA graph
    runner is created when enabled and applicable.

    Args:
        model_config: Model configuration; several flags (``is_generation``,
            ``is_multimodal``, ``is_hybrid``, ``attention_chunk_size``, ...)
            are copied from it onto ``self``.
        mem_fraction_static: Stored as ``self.mem_fraction_static``.
        gpu_id: Device index, stored as ``self.gpu_id``.
        tp_rank: Tensor-parallel rank of this runner.
        tp_size: Tensor-parallel world size.
        moe_ep_rank: Expert-parallel rank of this runner.
        moe_ep_size: Expert-parallel world size.
        pp_rank: Pipeline-parallel rank of this runner.
        pp_size: Pipeline-parallel world size.
        nccl_port: Stored as ``self.dist_port``.
        server_args: Server arguments; also the source of ``device``,
            ``dp_size`` and ``page_size``.
        dp_rank: Forwarded to ``create_offloader_from_server_args``.
        is_draft_worker: Whether this runner serves a draft model.
        req_to_token_pool: Optional pool, stored unchanged on ``self``.
        token_to_kv_pool_allocator: Optional allocator, stored unchanged on
            ``self``.
    """
    # Parse args
    self.mem_fraction_static = mem_fraction_static
    self.device = server_args.device
    self.gpu_id = gpu_id
    self.tp_rank = tp_rank
    self.tp_size = tp_size
    self.moe_ep_rank = moe_ep_rank
    self.moe_ep_size = moe_ep_size
    self.dp_size = server_args.dp_size
    self.pp_rank = pp_rank
    self.pp_size = pp_size
    self.model_config = model_config
    self.dist_port = nccl_port
    self.server_args = server_args
    self.is_draft_worker = is_draft_worker
    self.is_generation = model_config.is_generation
    self.is_multimodal = model_config.is_multimodal
    self.is_multimodal_chunked_prefill_supported = (
        model_config.is_multimodal_chunked_prefill_supported
    )
    self.spec_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )
    self.page_size = server_args.page_size
    self.req_to_token_pool = req_to_token_pool
    self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
    self.is_hybrid = model_config.is_hybrid
    self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
    self.attention_chunk_size = model_config.attention_chunk_size

    self.forward_pass_id = 0
    self.init_new_workspace = False

    # Enable time-cost reporting when requested
    if server_args.show_time_cost:
        enable_show_time_cost()

    # Model-specific adjustment
    # (must run before server_args is published globally, since it may
    # rewrite fields such as the attention backend)
    self.model_specific_adjustment()

    # Set the global server_args in the scheduler process
    set_global_server_args_for_scheduler(server_args)
    global_server_args = get_global_server_args()

    # FIXME: hacky set `use_mla_backend`
    global_server_args.use_mla_backend = self.use_mla_backend

    # Init OpenMP threads binding for CPU
    if self.device == "cpu":
        self.init_threads_binding()

    # Get memory before model loading
    min_per_gpu_memory = self.init_torch_distributed()

    # CPU offload
    set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

    if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
        slow_rank_detector.execute()

    # Update deep gemm configure
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

    # Initialize the model runner
    self.initialize(min_per_gpu_memory)

    # Temporary cached values
    # (whether the model's forward accepts a `pp_proxy_tensors` argument)
    self.support_pp = (
        "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
    )

    # For weight updates
    self._model_update_group = {}
    self._weights_send_group = {}

    if (
        self.server_args.enable_piecewise_cuda_graph
        and self.can_run_piecewise_cuda_graph()
    ):
        # Collect the attention module of every layer that exposes
        # `self_attn.attn`; the runner is only built if all layers do.
        self.attention_layers = []
        for layer in self.model.model.layers:
            if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
                self.attention_layers.append(layer.self_attn.attn)
        if len(self.attention_layers) < self.model_config.num_hidden_layers:
            # TODO(yuwei): support Non-Standard GQA
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
            )
            self.piecewise_cuda_graph_runner = None
        else:
            self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
    else:
        self.piecewise_cuda_graph_runner = None
def initialize(self, min_per_gpu_memory: float):
    """Load the model and build everything needed to run forward passes.

    In order: creates the memory saver adapter, sets up expert-location /
    expert-distribution globals (target worker only), loads the model,
    derives the layer range, applies torchao and torch TP, initializes LoRA
    and double sparsity when enabled, then initializes the memory pool, the
    attention backend and (where supported) the device graphs. For EAGLE3 on
    the target worker it finally tells the model which layers to capture.

    Args:
        min_per_gpu_memory: Value returned by ``init_torch_distributed``;
            forwarded to ``init_memory_pool``.

    Raises:
        ValueError: If double sparsity is enabled without
            ``ds_heavy_channel_type``.
        AssertionError: If an MTP model is used with speculative decoding
            while the model only holds a subset of the layers.
    """
    server_args = self.server_args

    self.memory_saver_adapter = TorchMemorySaverAdapter.create(
        enable=self.server_args.enable_memory_saver
    )

    if not self.is_draft_worker:
        set_global_expert_location_metadata(
            compute_initial_expert_location_metadata(server_args, self.model_config)
        )
        if self.tp_rank == 0 and get_bool_env_var(
            "SGLANG_LOG_EXPERT_LOCATION_METADATA"
        ):
            logger.info(
                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
            )

        set_global_expert_distribution_recorder(
            ExpertDistributionRecorder.init_new(
                server_args,
                get_global_expert_location_metadata(),
                rank=self.tp_rank,
            )
        )

    # Expert parallelism
    self.eplb_manager = (
        EPLBManager(self)
        if self.server_args.enable_eplb and (not self.is_draft_worker)
        else None
    )
    self.expert_location_updater = ExpertLocationUpdater()

    # Plain `if` instead of a conditional expression used only for its
    # side effect.
    if self.server_args.elastic_ep_backend:
        ElasticEPStateManager.init(self.server_args)

    # Load the model
    self.sampler = Sampler()
    self.load_model()

    # Check if the model is using hybrid SWA
    if (
        not self.server_args.disable_hybrid_swa_memory
        and self.sliding_window_size is not None
        and self.sliding_window_size > 0
    ):
        architectures = self.model_config.hf_config.architectures
        if architectures and not any("Llama4" in arch for arch in architectures):
            self.is_hybrid = self.model_config.is_hybrid = True

    if config := self.mamba2_config:
        class_name = config.__class__.__name__
        logger.warning(f"{class_name} model detected, disable radix cache")
        self.server_args.disable_radix_cache = True

    # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
    # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
    # determine the number of layers.
    model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
    model_num_layers = (
        self.model_config.num_nextn_predict_layers
        if self.is_draft_worker and model_has_mtp_layers
        else max(
            self.model_config.num_hidden_layers,
            self.model_config.num_attention_layers,
        )
    )
    self.start_layer = getattr(self.model, "start_layer", 0)
    self.end_layer = getattr(self.model, "end_layer", model_num_layers)
    self.num_effective_layers = self.end_layer - self.start_layer
    assert (
        (not model_has_mtp_layers)
        or (self.spec_algorithm.is_none())
        or (
            (not self.spec_algorithm.is_none())
            and (self.num_effective_layers == model_num_layers)
        )
    ), "PP is not compatible with MTP models."

    # Apply torchao quantization
    torchao_applied = getattr(self.model, "torchao_applied", False)
    # In layered loading, torchao may have been applied
    if not torchao_applied:
        apply_torchao_config_to_model(
            self.model, get_global_server_args().torchao_config
        )

    # Apply torch TP if the model supports it
    supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
    if self.tp_size > 1 and supports_torch_tp:
        self.apply_torch_tp()

    # Init lora
    if server_args.enable_lora:
        self.init_lora_manager()

    # Init Double Sparsity
    if server_args.enable_double_sparsity:
        if server_args.ds_heavy_channel_type is None:
            raise ValueError(
                "Please specify the heavy channel type for double sparsity optimization."
            )
        self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

    # Enable batch invariant mode
    if server_args.enable_deterministic_inference:
        from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

        enable_batch_invariant_mode()

    # Init memory pool and attention backends
    self.init_memory_pool(
        min_per_gpu_memory,
        server_args.max_running_requests,
        server_args.max_total_tokens,
    )
    if self.device == "cuda":
        self.init_cublas()
        self.init_attention_backend()
        self.init_device_graphs()
    elif self.device in ["npu", "cpu"]:
        self.init_attention_backend()
        self.init_device_graphs()
    else:
        # Other devices run without a device graph runner.
        self.graph_runner = None
        self.graph_mem_usage = 0
        self.init_attention_backend()

    # auxiliary hidden capture mode. TODO: expose this to server args?
    if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
        # load draft config
        draft_model_config = ModelConfig.from_server_args(
            server_args,
            model_path=(server_args.speculative_draft_model_path),
            is_draft_model=True,
        )

        try:
            # get the aux layer from draft model config
            eagle_config = getattr(
                draft_model_config.hf_config, "eagle_config", None
            )
            eagle_aux_hidden_state_layer_ids = eagle_config[
                "eagle_aux_hidden_state_layer_ids"
            ]
        except Exception:
            # if there is no aux layer, set to None.
            # `Exception` (not a bare `except`) so that KeyboardInterrupt /
            # SystemExit are not swallowed here.
            eagle_aux_hidden_state_layer_ids = None

        self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
def model_specific_adjustment(self):
    """Adjust ``self.server_args`` in place for the model and the platform.

    Resolves the attention backend (platform fallbacks, dual-chunk attention,
    auto-selection when unset, MLA validation, FP8 and double-sparsity
    overrides), then tunes chunked prefill, the chunked prefix cache, the
    static memory fraction, the hierarchical-cache decode backend, and
    validates the qwen3_vl_moe quantized MoE sharding. Each step reads values
    that earlier steps may have rewritten, so the order matters.

    Raises:
        ValueError: If dual chunk attention is configured with a different
            backend, if the backend is not valid for an MLA model, or if the
            qwen3_vl_moe tp/ep/block-size divisibility checks fail.
    """
    server_args = self.server_args

    # Platform fallbacks for Intel-specific backends.
    if (
        server_args.attention_backend == "intel_amx"
        and server_args.device == "cpu"
        and not _is_cpu_amx_available
    ):
        logger.info(
            "The current platform does not support Intel AMX, will fallback to torch_native backend."
        )
        server_args.attention_backend = "torch_native"

    if (
        server_args.attention_backend == "intel_xpu"
        and server_args.device == "xpu"
        and not _is_xpu_xmx_available
    ):
        logger.info(
            "The current platform does not support Intel XMX, will fallback to triton backend."
        )
        server_args.attention_backend = "triton"

    if server_args.prefill_attention_backend is not None and (
        server_args.prefill_attention_backend
        == server_args.decode_attention_backend
    ):  # override the default attention backend
        server_args.attention_backend = server_args.prefill_attention_backend

    # Models configured for dual chunk attention require its dedicated backend.
    if (
        getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
        is not None
    ):
        if server_args.attention_backend is None:
            server_args.attention_backend = "dual_chunk_flash_attn"
            logger.info("Dual chunk attention is turned on by default.")
        elif server_args.attention_backend != "dual_chunk_flash_attn":
            raise ValueError(
                "Dual chunk attention is enabled, but attention backend is set to "
                f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
            )

    if server_args.attention_backend is None:
        """
        Auto select the fastest attention backend.

        1. Models with MHA Architecture (e.g: Llama, QWen)
            1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
            1.2 In other cases, we will use flashinfer if available, otherwise use triton.
        2. Models with MLA Architecture and using FA3
            2.1 We will use FA3 backend on hopper.
            2.2 We will use Flashinfer backend on blackwell.
            2.3 Otherwise, we will use triton backend.
        """

        if not self.use_mla_backend:
            # MHA architecture
            if (
                is_hopper_with_cuda_12_3()
                and is_no_spec_infer_or_topk_one(server_args)
                and is_fa3_default_architecture(self.model_config.hf_config)
            ):
                server_args.attention_backend = "fa3"
            elif _is_hip:
                server_args.attention_backend = "aiter"
            elif _is_npu:
                server_args.attention_backend = "ascend"
            else:
                server_args.attention_backend = (
                    "flashinfer" if is_flashinfer_available() else "triton"
                )
        else:
            # MLA architecture
            if is_hopper_with_cuda_12_3():
                server_args.attention_backend = "fa3"
            elif is_sm100_supported():
                server_args.attention_backend = "flashinfer"
            elif _is_hip:
                head_num = self.model_config.get_num_kv_heads(self.tp_size)
                # TODO current aiter only support head number 16 or 128 head number
                if head_num == 128 or head_num == 16:
                    server_args.attention_backend = "aiter"
                else:
                    server_args.attention_backend = "triton"
            elif _is_npu:
                server_args.attention_backend = "ascend"
            else:
                server_args.attention_backend = "triton"
        log_info_on_rank0(
            logger,
            f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
        )
    elif self.use_mla_backend:
        # An explicitly chosen backend must be one that is valid for MLA.
        if server_args.device != "cpu":
            if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                logger.info(
                    f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                )
            else:
                raise ValueError(
                    f"Invalid attention backend for MLA: {server_args.attention_backend}"
                )
        else:
            if server_args.attention_backend != "intel_amx":
                raise ValueError(
                    "MLA optimization not supported on CPU except for intel_amx backend."
                )

    if (
        server_args.attention_backend == "fa3"
        and server_args.kv_cache_dtype == "fp8_e5m2"
    ):
        logger.warning(
            "FlashAttention3 only supports fp8_e4m3 if using FP8; "
            "Setting attention backend to triton."
        )
        server_args.attention_backend = "triton"

    # Double sparsity overrides whatever backend was selected above.
    if server_args.enable_double_sparsity:
        logger.info(
            "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
        )
        server_args.attention_backend = "triton"
        server_args.disable_cuda_graph = True

    if self.is_multimodal:
        if not self.is_multimodal_chunked_prefill_supported:
            server_args.chunked_prefill_size = -1
            logger.info(
                f"Automatically turn off --chunked-prefill-size as it is not supported for "
                f"{self.model_config.hf_config.model_type}"
            )

    # The chunked prefix cache stays enabled only for MLA models on a
    # supported backend.
    if (
        not self.use_mla_backend
        or server_args.attention_backend
        not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
    ):
        server_args.disable_chunked_prefix_cache = True

    if not server_args.disable_chunked_prefix_cache:
        log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

    # With the aiter backend and a context longer than 8192, shrink the
    # static memory fraction.
    if server_args.attention_backend == "aiter":
        if self.model_config.context_len > 8192:
            self.mem_fraction_static *= 0.85

    if (
        server_args.enable_hierarchical_cache
        and server_args.hicache_io_backend == "kernel"
    ):
        # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
        if server_args.decode_attention_backend is None:
            if not self.use_mla_backend:
                server_args.decode_attention_backend = (
                    "flashinfer" if is_flashinfer_available() else "triton"
                )
            else:
                server_args.decode_attention_backend = (
                    "flashinfer" if is_sm100_supported() else "triton"
                )
        elif server_args.decode_attention_backend == "fa3":
            server_args.hicache_io_backend = "direct"
            logger.warning(
                "FlashAttention3 decode backend is not compatible with hierarchical cache. "
                "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
            )

    # Sharding checks for quantized qwen3_vl_moe checkpoints.
    if self.model_config.hf_config.model_type == "qwen3_vl_moe":
        if (
            quantization_config := getattr(
                self.model_config.hf_config, "quantization_config", None
            )
        ) is not None:
            weight_block_size_n = quantization_config["weight_block_size"][0]

            if self.tp_size % self.moe_ep_size != 0:
                raise ValueError(
                    f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
                )
            moe_tp_size = self.tp_size // self.moe_ep_size

            moe_intermediate_size = (
                self.model_config.hf_text_config.moe_intermediate_size
            )
            if moe_intermediate_size % moe_tp_size != 0:
                raise ValueError(
                    f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
                )

            # The per-rank MoE shard must be a whole number of weight blocks.
            if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
                raise ValueError(
                    f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
                    f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
                    f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
                )
def init_torch_distributed(self):
    """Select the device, initialize the process groups, and check memory balance.

    On the target (non-draft) worker this initializes the distributed
    environment, the model-parallel groups and DP attention. On every worker
    it then caches the TP / PP / attention-TP groups on ``self`` and, for
    ``tp_size > 1`` on the target worker, checks that the minimum reported
    memory is at least 90% of the local available memory.

    Returns:
        The value of ``get_available_gpu_memory(..., distributed=...)``
        measured after group initialization (presumably the minimum
        available memory across ranks, in GB; confirm against the helper).

    Raises:
        ValueError: If the memory check fails and
            ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not set.
        Exception: Whatever ``set_device`` raises, re-raised after logging
            the device context.
    """
    logger.info("Init torch distributed begin.")

    try:
        torch.get_device_module(self.device).set_device(self.gpu_id)
    except Exception:
        logger.warning(
            f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
        )
        raise

    # Pick the communication backend for the device type.
    if self.device == "cuda":
        if self.server_args.elastic_ep_backend == "mooncake":
            backend = "mooncake"
            if self.server_args.mooncake_ib_device:
                mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                try:
                    from mooncake import ep as mooncake_ep

                    mooncake_ep.set_device_filter(mooncake_ib_device)
                except Exception:
                    # Best effort: a warning will be raised in
                    # `init_distributed_environment`. `Exception` (not a bare
                    # `except`) keeps KeyboardInterrupt / SystemExit
                    # propagating.
                    logger.debug(
                        "Failed to apply the mooncake IB device filter.",
                        exc_info=True,
                    )
        else:
            backend = "nccl"
    elif self.device == "xpu":
        backend = "xccl"
    elif self.device == "hpu":
        backend = "hccl"
    elif self.device == "cpu":
        backend = "gloo"
    elif self.device == "npu":
        backend = "hccl"

    before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    if not self.server_args.enable_p2p_check:
        monkey_patch_p2p_access_check()

    if self.server_args.dist_init_addr:
        dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
    else:
        dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
    set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
    set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
    set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

    if not self.is_draft_worker:
        if self.device == "cpu":
            if _is_cpu_amx_available:
                # Bind OpenMP threads to CPU cores
                torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                # Set local size to hint SGLang to use shared memory based AllReduce
                os.environ["LOCAL_SIZE"] = str(self.tp_size)
                torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                # NOTE(review): this local function is never referenced in
                # the visible code; presumably it is meant to be registered
                # (e.g. via a decorator) as a shape-only stand-in for a
                # shared-memory all-gather op. Confirm against upstream.
                def _(data, dim):
                    return torch.cat([data] * self.tp_size, dim=dim)

            else:
                logger.warning(
                    "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
                )

        # Only initialize the distributed environment on the target model worker.
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size * self.pp_size,
            rank=self.tp_size * self.pp_rank + self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
            timeout=self.server_args.dist_timeout,
        )
        initialize_model_parallel(
            tensor_model_parallel_size=self.tp_size,
            pipeline_model_parallel_size=self.pp_size,
            expert_model_parallel_size=self.moe_ep_size,
            duplicate_tp_group=self.server_args.enable_pdmux,
            torch_compile=self.server_args.enable_piecewise_cuda_graph,
        )
        initialize_dp_attention(
            server_args=self.server_args,
            model_config=self.model_config,
        )

    min_per_gpu_memory = get_available_gpu_memory(
        self.device,
        self.gpu_id,
        distributed=get_world_group().world_size > 1,
        cpu_group=get_world_group().cpu_group,
    )
    self.tp_group = get_tp_group()
    self.pp_group = get_pp_group()
    self.attention_tp_group = get_attention_tp_group()

    # Check memory for tensor parallelism
    local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
    if self.tp_size > 1 and not self.is_draft_worker:
        if min_per_gpu_memory < local_gpu_memory * 0.9:
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                )
            else:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                )

    logger.info(
        f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
    )
    return min_per_gpu_memory
| def load_model(self): | |
| before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) | |
| logger.info( | |
| f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" | |
| ) | |
| # This can reduce thread conflicts and speed up weight loading. | |
| if self.device != "cpu": | |
| torch.set_num_threads(1) | |
| if self.device == "cuda": | |
| if torch.cuda.get_device_capability()[0] < 8: | |
| logger.info( | |
| "Compute capability below sm80. Use float16 due to lack of bfloat16 support." | |
| ) | |
| self.server_args.dtype = "float16" | |
| self.model_config.dtype = torch.float16 | |
| if torch.cuda.get_device_capability()[1] < 5: | |
| raise RuntimeError("SGLang only supports sm75 and above.") | |
| set_cuda_arch() | |
| # Prepare the model config | |
| from sglang.srt.configs.modelopt_config import ModelOptConfig | |
| modelopt_config = ModelOptConfig( | |
| quant=self.server_args.modelopt_quant, | |
| checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path, | |
| checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path, | |
| export_path=self.server_args.modelopt_export_path, | |
| quantize_and_serve=self.server_args.quantize_and_serve, | |
| ) | |
| self.load_config = LoadConfig( | |
| load_format=self.server_args.load_format, | |
| download_dir=self.server_args.download_dir, | |
| model_loader_extra_config=self.server_args.model_loader_extra_config, | |
| tp_rank=self.tp_rank, | |
| remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip, | |
| remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port, | |
| remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, | |
| modelopt_config=modelopt_config, | |
| ) | |
| if self.device == "cpu": | |
| self.model_config = adjust_config_with_unaligned_cpu_tp( | |
| self.model_config, self.load_config, self.tp_size | |
| ) | |
| if self.server_args.load_format == "gguf": | |
| monkey_patch_vllm_gguf_config() | |
| if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE: | |
| if self.tp_rank == 0: | |
| instance_ip = socket.gethostbyname(socket.gethostname()) | |
| t = threading.Thread( | |
| target=trigger_init_weights_send_group_for_remote_instance_request, | |
| args=( | |
| self.server_args.remote_instance_weight_loader_seed_instance_ip, | |
| self.server_args.remote_instance_weight_loader_seed_instance_service_port, | |
| self.server_args.remote_instance_weight_loader_send_weights_group_ports, | |
| instance_ip, | |
| ), | |
| ) | |
| t.start() | |
| # Load the model | |
| # Remove monkey_patch when linear.py quant remove dependencies with vllm | |
| monkey_patch_vllm_parallel_state() | |
| monkey_patch_isinstance_for_vllm_base_layer() | |
| with self.memory_saver_adapter.region( | |
| GPU_MEMORY_TYPE_WEIGHTS, | |
| enable_cpu_backup=self.server_args.enable_weights_cpu_backup, | |
| ): | |
| self.model = get_model( | |
| model_config=self.model_config, | |
| load_config=self.load_config, | |
| device_config=DeviceConfig(self.device, self.gpu_id), | |
| ) | |
| monkey_patch_vllm_parallel_state(reverse=True) | |
| monkey_patch_isinstance_for_vllm_base_layer(reverse=True) | |
| get_offloader().post_init() | |
| if self.server_args.kv_cache_dtype == "fp8_e4m3": | |
| if self.server_args.quantization_param_path is not None: | |
| if callable(getattr(self.model, "load_kv_cache_scales", None)): | |
| self.model.load_kv_cache_scales( | |
| self.server_args.quantization_param_path | |
| ) | |
| logger.info( | |
| "Loaded KV cache scaling factors from %s", | |
| self.server_args.quantization_param_path, | |
| ) | |
| else: | |
| raise RuntimeError( | |
| "Using FP8 KV cache and scaling factors provided but " | |
| "model %s does not support loading scaling factors.", | |
| self.model.__class__, | |
| ) | |
| else: | |
| logger.warning( | |
| "Using FP8 KV cache but no scaling factors " | |
| "provided. Defaulting to scaling factors of 1.0. " | |
| "This may lead to less accurate results!" | |
| ) | |
| # Parse other args | |
| self.sliding_window_size = None | |
| if hasattr(self.model, "get_attention_sliding_window_size"): | |
| self.sliding_window_size = self.model.get_attention_sliding_window_size() | |
| elif self.model_config.attention_chunk_size is not None: | |
| self.sliding_window_size = self.model_config.attention_chunk_size | |
| logger.info( | |
| f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" | |
| ) | |
| self.dtype = self.model_config.dtype | |
| after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) | |
| self.weight_load_mem_usage = before_avail_memory - after_avail_memory | |
| logger.info( | |
| f"Load weight end. " | |
| f"type={type(self.model).__name__}, " | |
| f"dtype={self.dtype}, " | |
| f"avail mem={after_avail_memory:.2f} GB, " | |
| f"mem usage={self.weight_load_mem_usage:.2f} GB." | |
| ) | |
| if self.server_args.elastic_ep_backend == "mooncake": | |
| # Mooncake does not support `monitored_barrier` | |
| dist.barrier(group=get_tp_group().cpu_group) | |
| else: | |
| # Handle the case where some ranks do not finish loading. | |
| try: | |
| dist.monitored_barrier( | |
| group=get_tp_group().cpu_group, | |
| timeout=datetime.timedelta( | |
| seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S | |
| ), | |
| wait_all_ranks=True, | |
| ) | |
| except RuntimeError: | |
| raise ValueError( | |
| f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." | |
| ) from None | |
def update_expert_location(
    self,
    new_expert_location_metadata: ExpertLocationMetadata,
    update_layer_ids: List[int],
):
    """Apply a new expert placement to the given layers.

    When an ``ElasticEPStateManager`` instance exists, the global expert
    location metadata is updated in place and the routed-expert weights are
    reloaded from disk. Otherwise the update is delegated to
    ``self.expert_location_updater``.

    Args:
        new_expert_location_metadata: Metadata describing the new placement.
        update_layer_ids: Ids of the layers whose placement should change.
    """
    if ElasticEPStateManager.instance() is not None:
        # TODO: refactor the weights update when elastic ep
        old_expert_location_metadata = get_global_expert_location_metadata()
        assert old_expert_location_metadata is not None
        old_expert_location_metadata.update(
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
        )
        # Only weights whose name contains "mlp.experts" and not
        # "mlp.shared_experts" are reloaded.
        success, message = self.update_weights_from_disk(
            self.server_args.model_path,
            self.server_args.load_format,
            lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
        )
        if not success:
            # The reload reports failure through its return value; surface it
            # instead of dropping it silently.
            logger.error(
                f"Failed to reload expert weights after updating expert location: {message}"
            )
    else:
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )
def update_weights_from_disk(
    self,
    model_path: str,
    load_format: str,
    weight_name_filter: Optional[Callable[[str], bool]] = None,
) -> tuple[bool, str]:
    """Update engine weights in-place from the disk.

    Args:
        model_path: Location of the checkpoint to load.
        load_format: Load format used to build the ``LoadConfig``.
        weight_name_filter: Optional predicate on weight names; when given,
            only weights for which it returns True are loaded.

    Returns:
        ``(success, message)``. On failure ``self.model_config.model_path`` is
        restored to its previous value; if loading itself failed, the weights
        are reloaded from that previous path before returning.
    """
    logger.info(
        f"Update engine weights online from disk begin. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )

    target_device = torch.device(self.device)
    # Remember where the current weights came from: the config is pointed at
    # the new path below, and every failure path must undo that.
    original_model_path = self.model_config.model_path
    self.model_config.model_path = model_path
    load_config = LoadConfig(load_format=load_format)

    # Only support DefaultModelLoader for now
    loader = get_model_loader(load_config, self.model_config)
    if not isinstance(loader, DefaultModelLoader):
        self.model_config.model_path = original_model_path
        message = f"Failed to get model loader: {loader}."
        return False, message

    def get_weight_iter(config):
        weights = loader._get_weights_iterator(
            DefaultModelLoader.Source.init_new(config, self.model)
        )
        if weight_name_filter is not None:
            weights = (
                (name, weight)
                for name, weight in weights
                if weight_name_filter(name)
            )
        return weights

    def model_load_weights(model, weights):
        DefaultModelLoader.load_weights_and_postprocess(model, weights, target_device)
        return model

    with set_default_torch_dtype(self.model_config.dtype):
        try:
            weight_iter = get_weight_iter(self.model_config)
        except Exception as e:
            self.model_config.model_path = original_model_path
            message = f"Failed to get weights iterator: {e}."
            return False, message
        try:
            model = model_load_weights(self.model, weight_iter)
        except Exception as e:
            message = (
                f"Failed to update weights: {e}.\nRolling back to original weights."
            )
            del weight_iter
            gc.collect()
            # Roll back from the ORIGINAL checkpoint path; without restoring
            # it the "rollback" would re-read the path that just failed.
            # NOTE(review): the loader was built with the new load_format;
            # confirm that is acceptable when the formats differ.
            self.model_config.model_path = original_model_path
            weight_iter = get_weight_iter(self.model_config)
            self.model = model_load_weights(self.model, weight_iter)
            return False, message

    self.model = model
    self.server_args.model_path = model_path
    self.server_args.load_format = load_format
    self.load_config = load_config

    logger.info("Update weights end.")
    return True, "Succeeded to update model weights."
def init_weights_send_group_for_remote_instance(
    self,
    master_address,
    ports,
    group_rank,
    world_size,
    group_name,
    backend="nccl",
):
    """Create the process group this rank uses to send weights to a remote instance.

    ``ports`` is a comma-separated string with one port per TP rank; this rank
    picks the entry at index ``self.tp_rank``. The group is stored in
    ``self._weights_send_group`` under ``"{group_name}_{port}_{tp_rank}"``.

    Returns:
        ``(success, message)``.
    """
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"

    port_per_rank = ports.split(",")
    assert (
        len(port_per_rank) == self.tp_size
    ), f"Expected {self.tp_size} ports, but got {len(port_per_rank)} ports."
    group_port = port_per_rank[self.tp_rank]
    group_name = f"{group_name}_{group_port}_{self.tp_rank}"

    logger.info(
        f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
        f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
    )

    torch.cuda.empty_cache()
    try:
        process_group = init_custom_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{group_port}",
            world_size=world_size,
            rank=group_rank,
            group_name=group_name,
            device_id=torch.device("cuda", self.gpu_id),
        )
        self._weights_send_group[group_name] = process_group
        # Wait until every member of the new group has joined.
        dist.barrier(group=process_group)
    except Exception as e:
        success = False
        message = f"Failed to init group: {e}."
        logger.error(message)
    else:
        success = True
        message = (
            f"Succeeded to init group through {master_address}:{group_port} group."
        )

    torch.cuda.empty_cache()
    return success, message
def send_weights_to_remote_instance(
    self,
    master_address,
    ports,
    group_name,
):
    """Broadcast every model parameter over a previously created send group.

    ``ports`` is a comma-separated string with one port per TP rank. The group
    is looked up in ``self._weights_send_group`` under
    ``"{group_name}_{port}_{tp_rank}"``. It is removed from that mapping and
    destroyed after the send attempt, whether or not the send succeeded.

    Returns:
        ``(success, message)``. ``(False, ...)`` is returned without touching
        any group when the group has not been initialized.
    """
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"

    ports_list = ports.split(",")
    assert (
        len(ports_list) == self.tp_size
    ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
    group_port = ports_list[self.tp_rank]
    group_name = f"{group_name}_{group_port}_{self.tp_rank}"

    # Use .get() so an unknown group produces the error tuple below rather
    # than an uncaught KeyError from direct indexing.
    send_group = self._weights_send_group.get(group_name)
    if send_group is None:
        message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
        logger.error(message)
        return False, message

    torch.cuda.empty_cache()
    success = False
    message = ""
    try:
        for _, weights in self.model.named_parameters():
            torch.distributed.broadcast(
                weights,
                src=0,
                group=send_group,
            )
        success = True
        message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
    except Exception as e:
        message = f"Failed to send weights: {e}."
        logger.error(message)

    # destroy the process group after sending weights
    del self._weights_send_group[group_name]
    torch.distributed.distributed_c10d.destroy_process_group(send_group)
    torch.cuda.empty_cache()
    return success, message
def init_weights_update_group(
    self,
    master_address,
    master_port,
    rank_offset,
    world_size,
    group_name,
    backend="nccl",
):
    """Initialize the Torch process group for model parameter updates.

    `_model_update_group` is used in the RLHF workflow, where rank
    0 is the actor model in the training engine, and the other ranks are
    the inference engine, which is used for rollout.

    In the RLHF workflow, the training engine updates the model
    weights/parameters online, and broadcasts them to the inference
    engine through the `_model_update_group` process group.

    Returns:
        ``(success, message)``.
    """
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"

    # This worker's rank inside the new group.
    rank = rank_offset + self.tp_rank

    logger.info(
        f"init custom process group: master_address={master_address}, master_port={master_port}, "
        f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
    )

    try:
        process_group = init_custom_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
        self._model_update_group[group_name] = process_group
    except Exception as e:
        message = f"Failed to initialize custom process group: {e}."
        logger.error(message)
        return False, message
    return True, "Succeeded to initialize custom process group."
def destroy_weights_update_group(self, group_name):
    """Remove and destroy the update process group registered as ``group_name``.

    Returns:
        ``(success, message)``; ``(False, ...)`` when no such group is
        registered or when destroying it raises.
    """
    try:
        if group_name not in self._model_update_group:
            return False, "The group to be destroyed does not exist."
        process_group = self._model_update_group.pop(group_name)
        torch.distributed.destroy_process_group(process_group)
        return True, "Succeeded to destroy custom process group."
    except Exception as e:
        message = f"Failed to destroy custom process group: {e}."
        logger.error(message)
        return False, message
def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
    """
    Update parameters in the model weights online
    through the `_model_update_group` process group.

    For each (name, dtype, shape) triple an empty tensor is allocated on
    ``self.device`` and filled by an asynchronous broadcast from rank 0 of
    the group; once all broadcasts finish the tensors are handed to
    ``self.model.load_weights``.

    Args:
        names: names of the parameters to be updated.
        dtypes: data type of each parameter, either a ``torch.dtype`` or the
            name of an attribute of ``torch``.
        shapes: shape of each parameter.
        group_name: key of the process group in ``_model_update_group``.

    Returns:
        ``(success, message)``.
    """
    assert group_name in self._model_update_group, (
        f"Group {group_name} not in {list(self._model_update_group.keys())}. "
        "Please call `init_weights_update_group` first."
    )

    try:
        received = []
        pending = []
        for param_name, dtype_spec, param_shape in zip(names, dtypes, shapes):
            if isinstance(dtype_spec, torch.dtype):
                torch_dtype = dtype_spec
            else:
                torch_dtype = getattr(torch, dtype_spec)
            buffer = torch.empty(param_shape, dtype=torch_dtype, device=self.device)
            work = torch.distributed.broadcast(
                buffer,
                src=0,
                group=self._model_update_group[group_name],
                async_op=True,
            )
            pending.append(work)
            received.append((param_name, buffer))

        # All broadcasts are in flight; block until each one completes.
        for work in pending:
            work.wait()

        self.model.load_weights(received)
        return True, "Succeeded to update parameter online."

    except Exception as e:
        error_msg = (
            f"Failed to update parameter online: {e}. "
            f"The full weights of the ModelRunner are partially updated. "
            f"Please discard the whole weights."
        )
        logger.error(error_msg)
        return False, error_msg
def update_weights_from_tensor(
    self,
    named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
    load_format: Optional[str] = None,
):
    """Load the given named tensors into the model.

    Args:
        named_tensors: ``(name, tensor)`` pairs; for the
            ``"flattened_bucket"`` format this is instead passed through as
            the bucket dict.
        load_format: ``None`` to use ``self.model.load_weights``, ``"direct"``,
            ``"flattened_bucket"``, or an entry of
            ``self.server_args.custom_weight_loader``.

    Returns:
        ``(True, "Success")``.

    Raises:
        NotImplementedError: if ``load_format`` is not recognized.
    """
    monkey_patch_torch_reductions()
    if load_format == "flattened_bucket":
        # Handle flattened bucket format
        return self._update_weights_from_flattened_bucket(
            flattened_tensor_bucket_dict=named_tensors
        )

    # We need to get device after patch otherwise the device would be wrong
    self.device_module = torch.get_device_module(self.device)
    current_device = self.device_module.current_device()

    unwrapped = []
    for tensor_name, payload in named_tensors:
        local_tensor = _unwrap_tensor(
            payload, tp_rank=self.tp_rank, device=current_device
        )
        unwrapped.append((tensor_name, local_tensor))

    if load_format == "direct":
        _model_load_weights_direct(self.model, unwrapped)
        return True, "Success"
    if load_format in self.server_args.custom_weight_loader:
        custom_loader = dynamic_import(load_format)
        custom_loader(self.model, unwrapped)
        return True, "Success"
    if load_format is None:
        self.model.load_weights(unwrapped)
        return True, "Success"
    raise NotImplementedError(f"Unknown load_format={load_format}")
def _update_weights_from_flattened_bucket(
    self,
    flattened_tensor_bucket_dict,
):
    """Handle flattened bucket format for weight updates.

    Reads the ``"flattened_tensor"`` and ``"metadata"`` entries of the dict,
    rebuilds the individual tensors through ``FlattenedTensorBucket`` and
    loads them with ``self.model.load_weights``.

    Returns:
        ``(True, "Success")``.
    """
    flat_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
    raw_metadata = flattened_tensor_bucket_dict["metadata"]

    # Convert metadata entries to our format
    bucket_metadata = [
        FlattenedTensorMetadata(
            name=entry.name,
            shape=entry.shape,
            dtype=entry.dtype,
            start_idx=entry.start_idx,
            end_idx=entry.end_idx,
            numel=entry.numel,
        )
        for entry in raw_metadata
    ]

    # Create bucket and reconstruct tensors
    bucket = FlattenedTensorBucket(
        flattened_tensor=flat_tensor, metadata=bucket_metadata
    )

    # Load the reconstructed tensors using the standard method
    self.model.load_weights(bucket.reconstruct_tensors())
    return True, "Success"
def get_weights_by_name(
    self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
    """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

    Only used for unit test with an unoptimized performance.
    For optimized performance, please use torch.save and torch.load.

    Returns:
        Whatever ``self.model.get_weights_by_name`` returns, or ``None`` if
        that call raises.
    """
    # TODO: (chenyang) Add support for Qwen models.
    try:
        weights = self.model.get_weights_by_name(
            name, truncate_size, tp_size=self.tp_size
        )
    except Exception as e:
        logger.error(f"Error when getting parameter {name}: {e}")
        return None
    return weights
def init_lora_manager(self):
    """Create ``self.lora_manager`` from the model, its config and the LoRA server args."""
    args = self.server_args
    manager_kwargs = dict(
        base_model=self.model,
        base_hf_config=self.model_config.hf_config,
        max_loras_per_batch=args.max_loras_per_batch,
        load_config=self.load_config,
        dtype=self.dtype,
        lora_backend=args.lora_backend,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
        max_lora_rank=args.max_lora_rank,
        target_modules=args.lora_target_modules,
        lora_paths=args.lora_paths,
        server_args=args,
    )
    self.lora_manager = LoRAManager(**manager_kwargs)
def load_lora_adapter(self, lora_ref: LoRARef):
    """Load a new lora adapter from disk or huggingface.

    Logs the available memory before and after, and returns the result of
    ``self.lora_manager.load_lora_adapter``.
    """
    mem_before = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"LoRA adapter loading starts: {lora_ref}. avail mem={mem_before:.2f} GB"
    )

    result = self.lora_manager.load_lora_adapter(lora_ref)

    mem_after = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"LoRA adapter loading completes: {lora_ref}. avail mem={mem_after:.2f} GB"
    )
    return result
def unload_lora_adapter(self, lora_ref: LoRARef):
    """Unload a lora adapter that was previously loaded during initialization or dynamic loading.

    Logs the available memory before and after, and returns the result of
    ``self.lora_manager.unload_lora_adapter``.
    """
    mem_before = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"LoRA adapter unloading starts: {lora_ref}. avail mem={mem_before:.2f} GB"
    )

    result = self.lora_manager.unload_lora_adapter(lora_ref)

    mem_after = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"LoRA adapter unloading completes: {lora_ref}. avail mem={mem_after:.2f} GB"
    )
    return result
def profile_max_num_token(self, total_gpu_memory: int):
    """Estimate how many KV-cache token slots fit in the remaining memory.

    Computes the per-token cache size in bytes (``cell_size``) for the active
    attention layout, subtracts the non-static share of ``total_gpu_memory``
    from the currently available memory, lets the mamba cache take its part
    when a mamba-like config is present, and divides what is left by
    ``cell_size``. Memory values are scaled by ``1 << 30`` to bytes
    (presumably GB, as in this file's log messages).

    Args:
        total_gpu_memory: Total device memory, in the same unit as the value
            returned by ``get_available_gpu_memory``.

    Returns:
        The maximum number of tokens as an ``int``.
    """
    available_gpu_memory = get_available_gpu_memory(
        self.device,
        self.gpu_id,
        distributed=get_world_group().world_size > 1,
        cpu_group=get_world_group().cpu_group,
    )
    model_config = self.model_config

    # Number of layers that need a KV cache entry per token.
    if self.is_draft_worker:
        num_layers = getattr(
            model_config.hf_config,
            "num_nextn_predict_layers",
            self.num_effective_layers,
        )
    elif mamba_config := self.mambaish_config:
        num_layers = len(mamba_config.full_attention_layer_ids)
    else:
        num_layers = self.num_effective_layers

    if self.use_mla_backend:
        latent_dim = model_config.kv_lora_rank + model_config.qk_rope_head_dim
        cell_size = (
            latent_dim * num_layers * torch._utils._element_size(self.kv_cache_dtype)
        )
        # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
        if is_deepseek_nsa(model_config.hf_config):
            index_head_dim = get_nsa_index_head_dim(model_config.hf_config)
            indexer_size_per_token = (
                index_head_dim
                + index_head_dim // NSATokenToKVPool.quant_block_size * 4
            )
            indexer_element_size = torch._utils._element_size(
                NSATokenToKVPool.index_k_with_scale_buffer_dtype
            )
            cell_size += indexer_size_per_token * num_layers * indexer_element_size
    else:
        # The factor 2 accounts for storing both K and V.
        cell_size = (
            model_config.get_num_kv_heads(get_attention_tp_size())
            * model_config.head_dim
            * num_layers
            * 2
            * torch._utils._element_size(self.kv_cache_dtype)
        )

    reserved_memory = total_gpu_memory * (1 - self.mem_fraction_static)
    rest_memory = available_gpu_memory - reserved_memory
    if self.mambaish_config is not None:
        rest_memory = self.handle_max_mamba_cache(rest_memory)

    return int(rest_memory * (1 << 30) // cell_size)
def handle_max_mamba_cache(self, total_rest_memory):
    """Size the mamba state cache and return the memory left for the KV cache.

    Sets ``server_args.max_mamba_cache_size`` as a side effect, then subtracts
    the memory that many mamba states occupy from ``total_rest_memory``.
    Memory values are scaled by ``1 << 30`` when converted to/from
    ``mamba_cache_per_req``.

    Args:
        total_rest_memory: Memory available for mamba state plus KV cache.

    Returns:
        ``total_rest_memory`` minus the mamba state memory.
    """
    config = self.mambaish_config
    server_args = self.server_args
    assert config is not None

    draft_tokens = server_args.speculative_num_draft_tokens
    speculative_ratio = draft_tokens if draft_tokens is not None else 0
    states_per_req = 1 + speculative_ratio
    bytes_per_unit = 1 << 30

    ratio_based = not (
        server_args.disable_radix_cache
        or config.mamba2_cache_params.mamba_cache_per_req == 0
    )
    if ratio_based:
        # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
        # solve the equations:
        # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
        # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
        mamba_state_memory_raw = (
            total_rest_memory
            * server_args.mamba_full_memory_ratio
            / (1 + server_args.mamba_full_memory_ratio)
        )
        # calculate the max_mamba_cache_size based on the given total mamba memory
        server_args.max_mamba_cache_size = int(
            (mamba_state_memory_raw * bytes_per_unit)
            // config.mamba2_cache_params.mamba_cache_per_req
            // states_per_req
        )
    elif server_args.max_mamba_cache_size is None:
        # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests
        if server_args.max_running_requests is not None:
            server_args.max_mamba_cache_size = server_args.max_running_requests
        else:
            server_args.max_mamba_cache_size = 512

    if self.hybrid_gdn_config is not None:
        dp_divisor = server_args.dp_size if server_args.enable_dp_attention else 1
        server_args.max_mamba_cache_size = (
            server_args.max_mamba_cache_size // dp_divisor
        )

    mamba_state_memory = (
        server_args.max_mamba_cache_size
        * config.mamba2_cache_params.mamba_cache_per_req
        * states_per_req
        / bytes_per_unit
    )
    return total_rest_memory - mamba_state_memory
@property
def hybrid_gdn_config(self):
    """The HF config if it is a ``Qwen3NextConfig``, otherwise ``None``.

    Exposed as a property because the rest of the class reads it as an
    attribute (e.g. ``self.hybrid_gdn_config is not None``).
    """
    config = self.model_config.hf_config
    if isinstance(config, Qwen3NextConfig):
        return config
    return None


@property
def mamba2_config(self):
    """The HF config if it is a ``FalconH1Config`` or ``NemotronHConfig``, otherwise ``None``."""
    config = self.model_config.hf_config
    # Tuple form of isinstance: same result as `A | B`, without needing 3.10+.
    if isinstance(config, (FalconH1Config, NemotronHConfig)):
        return config
    return None


@property
def mambaish_config(self):
    """The first non-None of ``mamba2_config`` and ``hybrid_gdn_config``, or ``None``.

    Must be a property: callers use it in truth tests and read attributes of
    the returned config directly (``config := self.mambaish_config``).
    """
    return self.mamba2_config or self.hybrid_gdn_config
def set_num_token_hybrid(self):
    """Split the token budget between full-attention and sliding-window layers.

    Sets ``self.full_max_total_num_tokens`` and
    ``self.swa_max_total_num_tokens``, and overwrites
    ``self.max_total_num_tokens`` with the full-attention value.

    For ``Llama4ForConditionalGeneration`` the split is derived from
    ``attention_chunk_size / context_len`` and rounded down to a multiple of
    the page size. For other models the layers are classified by their
    ``sliding_window_size`` and the split follows
    ``server_args.swa_full_tokens_ratio``. If the model's layer list cannot be
    found, ``self.is_hybrid`` is set to False and nothing else changes.
    """
    if (
        "Llama4ForConditionalGeneration"
        in self.model_config.hf_config.architectures
    ):
        temp_ratio = (
            (1 - self.is_hybrid)
            + self.is_hybrid
            * self.attention_chunk_size
            / self.model_config.context_len
        )
        self.swa_max_total_num_tokens = (
            4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
        )
        self.full_max_total_num_tokens = (
            4 * self.max_total_num_tokens
            - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
        )
        # Round both budgets down to a whole number of pages.
        self.swa_max_total_num_tokens = int(
            self.swa_max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        self.full_max_total_num_tokens = int(
            self.full_max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        self.max_total_num_tokens = self.full_max_total_num_tokens
    else:
        assert self.sliding_window_size is not None and self.sliding_window_size > 0
        full_attention_layer_ids = []
        swa_attention_layer_ids = []

        # Probe the known places where the layer list may live. Only a
        # missing attribute means "try the next location"; anything else
        # (including KeyboardInterrupt/SystemExit) must propagate.
        candidate_paths = (
            ("model", "layers"),
            ("language_model", "model", "layers"),
            ("language_model", "layers"),
        )
        for path in candidate_paths:
            try:
                node = self.model
                for attr in path:
                    node = getattr(node, attr)
            except AttributeError:
                continue
            layers = node
            break
        else:
            self.is_hybrid = False
            return

        for layer in layers:
            if (
                layer.self_attn.attn.sliding_window_size is None
                or layer.self_attn.attn.sliding_window_size == -1
            ):
                full_attention_layer_ids.append(layer.layer_id)
            else:
                swa_attention_layer_ids.append(layer.layer_id)
        self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
        self.model_config.full_attention_layer_ids = full_attention_layer_ids

        # Algorithm:
        # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens.
        # - Find total # of tokens available across layers.
        # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
        total_tokens = (
            self.max_total_num_tokens * self.model_config.num_hidden_layers
        )
        full_layers_num = len(full_attention_layer_ids)
        swa_layers_num = len(swa_attention_layer_ids)
        swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

        # Solve the equations:
        # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
        # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
        denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
        self.full_max_total_num_tokens = int(total_tokens / denominator)
        self.swa_max_total_num_tokens = int(
            self.full_max_total_num_tokens * swa_full_tokens_ratio
        )
        self.max_total_num_tokens = self.full_max_total_num_tokens

        logger.info(
            f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
        )
def can_run_piecewise_cuda_graph(self):
    """Return whether piecewise CUDA graph can be used with the current settings.

    Returns False, after logging the reason on rank 0, when CUDA graph is
    disabled, torch compile is enabled, or pipeline parallelism is in use.
    """
    if self.server_args.disable_cuda_graph:
        reason = "Disable piecewise CUDA graph because disable_cuda_graph is set"
    elif self.server_args.enable_torch_compile:
        reason = "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile"
    elif self.pp_size > 1:
        # TODO(yuwei): support PP
        reason = "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP"
    else:
        return True

    log_info_on_rank0(logger, reason)
    return False
| def init_memory_pool( | |
| self, | |
| total_gpu_memory: int, | |
| max_num_reqs: Optional[int] = None, | |
| max_total_tokens: Optional[int] = None, | |
| ): | |
| # Determine the kv cache dtype | |
| if self.server_args.kv_cache_dtype == "auto": | |
| quant_config = getattr(self.model, "quant_config", None) | |
| kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None) | |
| if ( | |
| isinstance(kv_cache_quant_algo, str) | |
| and kv_cache_quant_algo.upper() == "FP8" | |
| ): | |
| if _is_hip: | |
| self.kv_cache_dtype = torch.float8_e4m3fnuz | |
| else: | |
| self.kv_cache_dtype = torch.float8_e4m3fn | |
| else: | |
| self.kv_cache_dtype = self.dtype | |
| elif self.server_args.kv_cache_dtype == "fp8_e5m2": | |
| if _is_hip: # Using natively supported format | |
| self.kv_cache_dtype = torch.float8_e5m2fnuz | |
| else: | |
| self.kv_cache_dtype = torch.float8_e5m2 | |
| elif self.server_args.kv_cache_dtype == "fp8_e4m3": | |
| if _is_hip: # Using natively supported format | |
| self.kv_cache_dtype = torch.float8_e4m3fnuz | |
| else: | |
| self.kv_cache_dtype = torch.float8_e4m3fn | |
| elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"): | |
| self.kv_cache_dtype = torch.bfloat16 | |
| else: | |
| raise ValueError( | |
| f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." | |
| ) | |
| log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}") | |
| self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) | |
| if SGLANG_CI_SMALL_KV_SIZE: | |
| self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE) | |
| if max_num_reqs is None: | |
| max_num_reqs = min( | |
| max( | |
| int( | |
| self.max_total_num_tokens / self.model_config.context_len * 512 | |
| ), | |
| 2048, | |
| ), | |
| 4096, | |
| ) | |
| if self.mambaish_config is not None: | |
| ratio = ( | |
| MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO | |
| if not self.server_args.disable_radix_cache | |
| else 1 | |
| ) | |
| max_num_reqs = min( | |
| max_num_reqs, self.server_args.max_mamba_cache_size // ratio | |
| ) | |
| if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): | |
| if self.is_draft_worker: | |
| self.max_total_num_tokens = self.server_args.draft_runner_cache_size | |
| max_num_reqs = self.server_args.max_num_reqs | |
| else: | |
| # We are sharing the `token_to_kv_pool`, and both verify and draft tokens | |
| # can be concurrently allocated, so we should give a headroom for it. | |
| self.server_args.draft_runner_cache_size = ( | |
| self.max_total_num_tokens | |
| # draft | |
| + max_num_reqs | |
| * self.server_args.speculative_num_steps | |
| * self.server_args.speculative_eagle_topk | |
| # verify | |
| + max_num_reqs * self.server_args.speculative_num_draft_tokens | |
| # buffer | |
| + 100 | |
| ) | |
| # Target worker and draft worker shares the same indices for the | |
| # token_to_kv_pool, so we should make sure to match max_total_num_tokens. | |
| self.max_total_num_tokens = self.server_args.draft_runner_cache_size | |
| self.server_args.max_num_reqs = max_num_reqs | |
| if max_total_tokens is not None: | |
| if max_total_tokens > self.max_total_num_tokens: | |
| logging.warning( | |
| f"max_total_tokens={max_total_tokens} is larger than the profiled value " | |
| f"{self.max_total_num_tokens}. " | |
| f"Use the profiled value instead." | |
| ) | |
| self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) | |
| self.max_total_num_tokens = ( | |
| self.max_total_num_tokens | |
| // self.server_args.page_size | |
| * self.server_args.page_size | |
| ) | |
| # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens | |
| if self.pp_size > 1: | |
| tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64) | |
| torch.distributed.all_reduce( | |
| tensor, | |
| op=torch.distributed.ReduceOp.MIN, | |
| group=get_world_group().cpu_group, | |
| ) | |
| self.max_total_num_tokens = tensor.item() | |
| # create token size for hybrid cache | |
| if self.is_hybrid: | |
| self.set_num_token_hybrid() | |
| if self.max_total_num_tokens <= 0: | |
| raise RuntimeError( | |
| f"Not enough memory. Please try to increase --mem-fraction-static. " | |
| f"Current value: {self.server_args.mem_fraction_static=}" | |
| ) | |
| # Initialize req_to_token_pool | |
| if self.req_to_token_pool is None: | |
| # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding | |
| extra_max_context_len = 4 | |
| if self.server_args.speculative_num_draft_tokens is not None: | |
| extra_max_context_len += self.server_args.speculative_num_draft_tokens | |
| if self.server_args.disaggregation_mode == "decode": | |
| from sglang.srt.disaggregation.decode import ( | |
| DecodeReqToTokenPool, | |
| HybridMambaDecodeReqToTokenPool, | |
| ) | |
| # subscribe memory for pre-allocated requests | |
| # if max_num_reqs <= 32, we pre-allocate 2x requests | |
| pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 | |
| if config := self.mambaish_config: | |
| self.req_to_token_pool = HybridMambaDecodeReqToTokenPool( | |
| size=max_num_reqs, | |
| max_context_len=self.model_config.context_len | |
| + extra_max_context_len, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| cache_params=config.mamba2_cache_params, | |
| speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, | |
| pre_alloc_size=pre_alloc_size, | |
| ) | |
| else: | |
| self.req_to_token_pool = DecodeReqToTokenPool( | |
| size=max_num_reqs, | |
| max_context_len=self.model_config.context_len | |
| + extra_max_context_len, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| pre_alloc_size=pre_alloc_size, | |
| ) | |
| elif config := self.mambaish_config: | |
| self.req_to_token_pool = HybridReqToTokenPool( | |
| size=max_num_reqs, | |
| mamba_size=self.server_args.max_mamba_cache_size, | |
| max_context_len=self.model_config.context_len | |
| + extra_max_context_len, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| cache_params=config.mamba2_cache_params, | |
| speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, | |
| ) | |
| else: | |
| self.req_to_token_pool = ReqToTokenPool( | |
| size=max_num_reqs, | |
| max_context_len=self.model_config.context_len | |
| + extra_max_context_len, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| ) | |
| else: | |
| # Draft worker shares req_to_token_pool with the target worker. | |
| assert self.is_draft_worker | |
| # Initialize token_to_kv_pool | |
| is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) | |
| if self.server_args.attention_backend == "ascend": | |
| if self.use_mla_backend: | |
| self.token_to_kv_pool = AscendMLAPagedTokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| kv_lora_rank=self.model_config.kv_lora_rank, | |
| qk_rope_head_dim=self.model_config.qk_rope_head_dim, | |
| index_head_dim=self.model_config.index_head_dim, | |
| layer_num=self.num_effective_layers, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| start_layer=self.start_layer, | |
| end_layer=self.end_layer, | |
| ) | |
| else: | |
| self.token_to_kv_pool = AscendTokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| head_num=self.model_config.get_num_kv_heads( | |
| get_attention_tp_size() | |
| ), | |
| head_dim=self.model_config.head_dim, | |
| layer_num=self.model_config.num_hidden_layers, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| ) | |
| elif self.use_mla_backend and is_nsa_model: | |
| self.token_to_kv_pool = NSATokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| kv_lora_rank=self.model_config.kv_lora_rank, | |
| qk_rope_head_dim=self.model_config.qk_rope_head_dim, | |
| layer_num=self.num_effective_layers, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| start_layer=self.start_layer, | |
| end_layer=self.end_layer, | |
| index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config), | |
| ) | |
| elif self.use_mla_backend: | |
| assert not is_nsa_model | |
| self.token_to_kv_pool = MLATokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| kv_lora_rank=self.model_config.kv_lora_rank, | |
| qk_rope_head_dim=self.model_config.qk_rope_head_dim, | |
| layer_num=self.num_effective_layers, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| start_layer=self.start_layer, | |
| end_layer=self.end_layer, | |
| ) | |
| elif self.server_args.enable_double_sparsity: | |
| self.token_to_kv_pool = DoubleSparseTokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), | |
| head_dim=self.model_config.head_dim, | |
| layer_num=self.num_effective_layers, | |
| device=self.device, | |
| heavy_channel_num=self.server_args.ds_heavy_channel_num, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| start_layer=self.start_layer, | |
| end_layer=self.end_layer, | |
| ) | |
| else: | |
| if self.is_hybrid: | |
| self.token_to_kv_pool = SWAKVPool( | |
| size=self.full_max_total_num_tokens, | |
| size_swa=self.swa_max_total_num_tokens, | |
| dtype=self.kv_cache_dtype, | |
| head_num=self.model_config.get_num_kv_heads( | |
| get_attention_tp_size() | |
| ), | |
| head_dim=self.model_config.head_dim, | |
| swa_attention_layer_ids=self.model_config.swa_attention_layer_ids, | |
| full_attention_layer_ids=self.model_config.full_attention_layer_ids, | |
| enable_kvcache_transpose=False, | |
| device=self.device, | |
| ) | |
| elif config := self.mambaish_config: | |
| self.token_to_kv_pool = HybridLinearKVPool( | |
| page_size=self.page_size, | |
| size=self.max_total_num_tokens, | |
| dtype=self.kv_cache_dtype, | |
| head_num=self.model_config.get_num_kv_heads( | |
| get_attention_tp_size() | |
| ), | |
| head_dim=self.model_config.head_dim, | |
| # if draft worker, we only need 1 attention layer's kv pool | |
| full_attention_layer_ids=( | |
| [0] if self.is_draft_worker else config.full_attention_layer_ids | |
| ), | |
| enable_kvcache_transpose=False, | |
| device=self.device, | |
| mamba_pool=self.req_to_token_pool.mamba_pool, | |
| ) | |
| else: | |
| self.token_to_kv_pool = MHATokenToKVPool( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| head_num=self.model_config.get_num_kv_heads( | |
| get_attention_tp_size() | |
| ), | |
| head_dim=self.model_config.head_dim, | |
| layer_num=self.num_effective_layers, | |
| device=self.device, | |
| enable_memory_saver=self.server_args.enable_memory_saver, | |
| start_layer=self.start_layer, | |
| end_layer=self.end_layer, | |
| enable_kv_cache_copy=( | |
| self.server_args.speculative_algorithm is not None | |
| ), | |
| ) | |
| # Initialize token_to_kv_pool_allocator | |
| need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") | |
| if self.token_to_kv_pool_allocator is None: | |
| if _is_npu and ( | |
| self.server_args.attention_backend == "ascend" | |
| or self.hybrid_gdn_config is not None | |
| ): | |
| self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| device=self.device, | |
| kvcache=self.token_to_kv_pool, | |
| need_sort=need_sort, | |
| ) | |
| else: | |
| if self.page_size == 1: | |
| if self.is_hybrid: | |
| self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator( | |
| self.full_max_total_num_tokens, | |
| self.swa_max_total_num_tokens, | |
| dtype=self.kv_cache_dtype, | |
| device=self.device, | |
| kvcache=self.token_to_kv_pool, | |
| need_sort=need_sort, | |
| ) | |
| else: | |
| self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( | |
| self.max_total_num_tokens, | |
| dtype=self.kv_cache_dtype, | |
| device=self.device, | |
| kvcache=self.token_to_kv_pool, | |
| need_sort=need_sort, | |
| ) | |
| else: | |
| assert not self.is_hybrid | |
| self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( | |
| self.max_total_num_tokens, | |
| page_size=self.page_size, | |
| dtype=self.kv_cache_dtype, | |
| device=self.device, | |
| kvcache=self.token_to_kv_pool, | |
| need_sort=need_sort, | |
| ) | |
| else: | |
| assert self.is_draft_worker | |
| logger.info( | |
| f"Memory pool end. " | |
| f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" | |
| ) | |
def init_cublas(self):
    """Warm up cuBLAS with one tiny half-precision matmul on the CUDA device.

    cuBLAS has to be initialized by an early matmul; otherwise errors are
    raised later on. The product of the two 16x16 operands is returned.
    """
    shape = (16, 16)
    lhs = torch.ones(shape, dtype=torch.float16, device="cuda")
    rhs = torch.ones(shape, dtype=torch.float16, device="cuda")
    return lhs @ rhs
def init_attention_backend(self):
    """Create the attention kernel backend and store it in ``self.attn_backend``.

    When two-batch overlap is enabled and this runner is not a draft worker,
    the backend factory itself is handed to ``TboAttnBackend.init_new``;
    in every other case the factory is simply called once.
    """
    use_two_batch_overlap = (
        self.server_args.enable_two_batch_overlap and not self.is_draft_worker
    )
    if not use_two_batch_overlap:
        self.attn_backend = self._get_attention_backend()
        return
    self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
def _get_attention_backend(self):
    """Build the attention backend described by the server arguments.

    The prefill/decode backend names come from
    ``server_args.get_attention_backends()`` and are recorded on ``self``.
    If the two names differ, a ``HybridAttnBackend`` wrapping one backend per
    phase is created; otherwise a single backend is built from
    ``server_args.attention_backend``. The resolved names are also written
    back to the global server args before the backend is returned.
    """
    prefill_name, decode_name = self.server_args.get_attention_backends()
    self.prefill_attention_backend_str = prefill_name
    self.decode_attention_backend_str = decode_name

    backends_differ = decode_name != prefill_name
    if backends_differ:
        from sglang.srt.layers.attention.hybrid_attn_backend import (
            HybridAttnBackend,
        )

        # Decode backend is constructed before the prefill backend.
        decode_backend = self._get_attention_backend_from_str(decode_name)
        prefill_backend = self._get_attention_backend_from_str(prefill_name)
        attn_backend = HybridAttnBackend(
            self,
            decode_backend=decode_backend,
            prefill_backend=prefill_backend,
        )
        logger.info(
            f"Using hybrid attention backend for decode and prefill: "
            f"decode_backend={decode_name}, "
            f"prefill_backend={prefill_name}."
        )
        logger.warning(
            "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
            "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
        )
    else:
        attn_backend = self._get_attention_backend_from_str(
            self.server_args.attention_backend
        )

    # Publish the resolved backend names on the global server args.
    get_global_server_args().prefill_attention_backend = prefill_name
    get_global_server_args().decode_attention_backend = decode_name
    return attn_backend
def _get_attention_backend_from_str(self, backend_str: str):
    """Instantiate the backend registered under ``backend_str``.

    The registered factory is called with this runner and the result is
    passed through ``attn_backend_wrapper``.

    Raises:
        ValueError: if ``backend_str`` is not a key of ``ATTENTION_BACKENDS``.
    """
    if backend_str in ATTENTION_BACKENDS:
        raw_backend = ATTENTION_BACKENDS[backend_str](self)
        return attn_backend_wrapper(self, raw_backend)
    raise ValueError(f"Invalid attention backend: {backend_str}")
def init_double_sparsity_channel_config(self, selected_channel):
    """Load the double-sparsity channel config into ``self.sorted_channels``.

    Reads the JSON file at ``server_args.ds_channel_config_path`` and, for each
    layer index in ``[start_layer, end_layer)``, looks up the entry
    ``model.layers.<i>.self_attn.<selected_channel>_proj``, keeps the first
    ``ds_heavy_channel_num`` columns, and stores the result as a contiguous
    CUDA tensor.
    """
    proj_suffix = "".join((".", selected_channel, "_proj"))
    self.sorted_channels = []

    # load channel config
    with open(self.server_args.ds_channel_config_path, "r") as config_file:
        channel_config = json.load(config_file)

    for layer_id in range(self.start_layer, self.end_layer):
        layer_key = f"model.layers.{layer_id}.self_attn{proj_suffix}"
        layer_channels = torch.tensor(channel_config[layer_key])
        heavy_channels = layer_channels[:, : self.server_args.ds_heavy_channel_num]
        self.sorted_channels.append(heavy_channels.contiguous().cuda())
def init_device_graphs(self):
    """Capture device graphs.

    Resets ``self.graph_runner`` / ``self.graph_mem_usage`` and returns early
    when the model is not a generation model, when CUDA graphs are disabled on
    a non-CPU device, or when torch.compile is not enabled on CPU. Otherwise
    the graph runner matching ``self.device`` is built and the memory consumed
    by the capture is recorded.
    """
    self.graph_runner = None
    self.graph_mem_usage = 0

    if not self.is_generation:
        # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
        return

    on_cpu = self.device == "cpu"
    if on_cpu:
        if not self.server_args.enable_torch_compile:
            return
    elif self.server_args.disable_cuda_graph:
        return

    graph_kind = "cpu graph" if on_cpu else "cuda graph"
    tic = time.perf_counter()
    before_mem = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"Capture {graph_kind} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
    )

    # Any device other than cpu/npu falls back to the CUDA graph runner.
    runner_cls = {
        "cpu": CPUGraphRunner,
        "npu": NPUGraphRunner,
    }.get(self.device, CudaGraphRunner)
    self.graph_runner = runner_cls(self)

    after_mem = get_available_gpu_memory(self.device, self.gpu_id)
    self.graph_mem_usage = before_mem - after_mem
    logger.info(
        f"Capture {graph_kind} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
        f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
    )
def init_threads_binding(self):
    """Choose the CPU cores this TP rank binds to and store them in ``self.local_omp_cpuid``.

    The binding comes from the ``SGLANG_CPU_OMP_THREADS_BIND`` environment
    variable. When it is unset (treated as ``"all"``), rank ``i`` takes the
    core list of NUMA node ``i``, which requires ``tp_size`` to not exceed the
    NUMA node count. Otherwise the variable is split on ``|`` and must provide
    exactly one entry per TP rank.
    """
    bind_spec = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
    node_cpu_ids = get_cpu_ids_by_node()
    numa_count = len(node_cpu_ids)

    if bind_spec != "all":
        # Explicit binding: one "|"-separated entry per TP rank.
        per_rank_binds = bind_spec.split("|")
        assert self.tp_size == len(per_rank_binds), (
            f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
            f"Please double check your settings."
        )
        self.local_omp_cpuid = per_rank_binds[self.tp_rank]
        if self.tp_size > numa_count:
            logger.warning(
                f"TP size ({self.tp_size})is larger than numa node number ({numa_count}), "
                f"in this case the available memory amount of each rank cannot be determined in prior. "
                f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
            )
        return

    # Default binding: one NUMA node per TP rank.
    assert self.tp_size <= numa_count, (
        f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
        f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {numa_count}. "
        f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
        f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
        f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
        f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
        f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
        f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
    )
    if self.tp_size < numa_count:
        logger.warning(
            f"Detected the current machine has {numa_count} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
        )
    self.local_omp_cpuid = node_cpu_ids[self.tp_rank]
def apply_torch_tp(self):
    """Apply torch tensor parallelism to ``self.model``.

    Creates a device mesh of shape ``(tp_size,)`` on ``self.device`` and hands
    the model plus that mesh to ``tensor_parallel``.
    """
    logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
    from sglang.srt.layers.model_parallel import tensor_parallel

    mesh_shape = (self.tp_size,)
    tp_mesh = torch.distributed.init_device_mesh(self.device, mesh_shape)
    tensor_parallel(self.model, tp_mesh)
def forward_decode(
    self,
    forward_batch: ForwardBatch,
    skip_attn_backend_init: bool = False,
    pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
    """Run the model forward pass for a decode batch.

    Args:
        forward_batch: The batch to run.
        skip_attn_backend_init: When true, the attention backend's forward
            metadata is not (re)initialized for this batch.
        pp_proxy_tensors: Forwarded to the model only when ``self.support_pp``
            is set.

    Returns:
        Whatever ``self.model.forward`` returns.
    """
    if not skip_attn_backend_init:
        self.attn_backend.init_forward_metadata(forward_batch)

    # FIXME: add pp_proxy_tensors arg to all models
    extra_kwargs = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
    return self.model.forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        **extra_kwargs,
    )
def forward_extend(
    self,
    forward_batch: ForwardBatch,
    skip_attn_backend_init: bool = False,
    pp_proxy_tensors=None,
) -> LogitsProcessorOutput:
    """Run the model forward pass for an extend batch.

    Optional keyword arguments are assembled first: ``pp_proxy_tensors`` when
    ``self.support_pp`` is set, ``input_embeds`` (cast to bfloat16) when the
    batch carries embeddings, and ``get_embedding=True`` for non-generation
    models. If a piecewise CUDA graph runner exists and can run the batch, its
    replay result is returned; otherwise the model is called directly.
    """
    if not skip_attn_backend_init:
        self.attn_backend.init_forward_metadata(forward_batch)

    extra_kwargs = {}
    if self.support_pp:
        extra_kwargs["pp_proxy_tensors"] = pp_proxy_tensors
    embeds = forward_batch.input_embeds
    if embeds is not None:
        extra_kwargs["input_embeds"] = embeds.bfloat16()
    if not self.is_generation:
        extra_kwargs["get_embedding"] = True

    piecewise_runner = self.piecewise_cuda_graph_runner
    if piecewise_runner is not None and piecewise_runner.can_run(forward_batch):
        return piecewise_runner.replay(forward_batch, **extra_kwargs)

    return self.model.forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        **extra_kwargs,
    )
def forward_idle(
    self, forward_batch: ForwardBatch, pp_proxy_tensors=None
) -> LogitsProcessorOutput:
    """Run the model forward pass for an idle batch.

    No attention-backend metadata is initialized here; ``pp_proxy_tensors`` is
    forwarded to the model only when ``self.support_pp`` is set.
    """
    extra_kwargs = {"pp_proxy_tensors": pp_proxy_tensors} if self.support_pp else {}
    return self.model.forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        **extra_kwargs,
    )
def forward_split_prefill(
    self,
    forward_batch: ForwardBatch,
    reinit_attn_backend: bool = False,
    forward_count: int = 1,
) -> LogitsProcessorOutput:
    """Run one slice of a split prefill and advance the batch's split index.

    Attention metadata is initialized on the first slice (``split_index == 0``)
    or when ``reinit_attn_backend`` is set. The slice covers
    ``(split_index, split_index + forward_count)``, with the upper bound capped
    at ``model_config.num_hidden_layers``; afterwards ``split_index`` is moved
    to that upper bound.
    """
    if forward_batch.split_index == 0 or reinit_attn_backend:
        self.attn_backend.init_forward_metadata(forward_batch)

    split_begin = forward_batch.split_index
    split_end = min(
        split_begin + forward_count,
        self.model_config.num_hidden_layers,
    )
    result = self.model.forward_split_prefill(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        (split_begin, split_end),
    )
    forward_batch.split_index = split_end
    return result
def forward(
    self,
    forward_batch: ForwardBatch,
    skip_attn_backend_init: bool = False,
    pp_proxy_tensors: Optional[PPProxyTensors] = None,
    reinit_attn_backend: bool = False,
    split_forward_count: int = 1,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
    """Run one forward pass and return ``_forward_raw``'s result.

    The forward-pass counter is incremented, the raw forward runs inside the
    global expert-distribution recorder's ``with_forward_pass`` context, and
    the EPLB manager (if any) is notified once the pass has finished.
    """
    self.forward_pass_id += 1

    recorder = get_global_expert_distribution_recorder()
    with recorder.with_forward_pass(self.forward_pass_id, forward_batch):
        output = self._forward_raw(
            forward_batch,
            skip_attn_backend_init,
            pp_proxy_tensors,
            reinit_attn_backend,
            split_forward_count,
        )

    eplb_manager = self.eplb_manager
    if eplb_manager is not None:
        eplb_manager.on_forward_pass_end()

    return output
def _forward_raw(
    self,
    forward_batch: ForwardBatch,
    skip_attn_backend_init: bool,
    pp_proxy_tensors: Optional[PPProxyTensors],
    reinit_attn_backend: bool = False,
    split_forward_count: int = 1,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
    """Dispatch the batch to graph replay or to the matching eager forward.

    Returns:
        ``(output, can_run_graph)`` where ``can_run_graph`` tells whether the
        output came from ``self.graph_runner.replay``.

    Raises:
        ValueError: if the forward mode is none of decode, extend,
            split-prefill or idle.
    """
    # Pick the graph-mode predicate that matches the device, then call it.
    if self.device == "cpu":
        is_graph_mode = forward_batch.forward_mode.is_cpu_graph
    else:
        is_graph_mode = forward_batch.forward_mode.is_cuda_graph
    can_run_graph = bool(
        is_graph_mode()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch)
    )

    if can_run_graph:
        graph_output = self.graph_runner.replay(
            forward_batch,
            skip_attn_backend_init=skip_attn_backend_init,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        return graph_output, True

    # For MLP sync
    if forward_batch.global_num_tokens_cpu is not None:
        forward_batch.prepare_mlp_sync_batch(self)

    mode = forward_batch.forward_mode
    if mode.is_decode():
        output = self.forward_decode(
            forward_batch,
            skip_attn_backend_init=skip_attn_backend_init,
            pp_proxy_tensors=pp_proxy_tensors,
        )
    elif mode.is_extend():
        output = self.forward_extend(
            forward_batch,
            skip_attn_backend_init=skip_attn_backend_init,
            pp_proxy_tensors=pp_proxy_tensors,
        )
    elif mode.is_split_prefill():
        output = self.forward_split_prefill(
            forward_batch,
            reinit_attn_backend=reinit_attn_backend,
            forward_count=split_forward_count,
        )
    elif mode.is_idle():
        output = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
    else:
        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    needs_post_sync = (
        forward_batch.global_num_tokens_cpu is not None
        and self.pp_group.is_last_rank
    )
    if needs_post_sync:
        forward_batch.post_forward_mlp_sync_batch(output)

    return output, False
def _preprocess_logits(
    self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
):
    """Refresh the regex vocab mask, then apply the logits bias to ``next_token_logits``."""
    # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
    # was executed after we processed last batch's results.
    sampling_info.update_regex_vocab_mask()

    # Calculate logits bias and apply it to next_token_logits.
    next_logits = logits_output.next_token_logits
    sampling_info.apply_logits_bias(next_logits)
def sample(
    self,
    logits_output: LogitsProcessorOutput,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Sample next tokens, compute logprobs, and update ``logits_output``.

    Args:
        logits_output: The logits output from the model forward. A tuple is
            treated as several output streams that are sampled one by one.
        forward_batch: The forward batch that generated ``logits_output``.

    Returns:
        The sampler's next-token ids; for tuple input, the per-stream results
        stacked along a new last dimension.
    """
    # For duplex models with multiple output streams.
    if isinstance(logits_output, tuple):
        per_stream_ids = [
            self.sample(stream_logits, forward_batch)
            for stream_logits in logits_output
        ]
        return torch.stack(per_stream_ids, axis=-1)

    self._preprocess_logits(logits_output, forward_batch.sampling_info)

    # For prefill, we only use the position of the last token.
    if forward_batch.forward_mode.is_decode():
        sample_positions = forward_batch.positions
    else:
        sample_positions = forward_batch.seq_lens - 1

    # Sample the next tokens
    return self.sampler(
        logits_output,
        forward_batch.sampling_info,
        forward_batch.return_logprob,
        forward_batch.top_logprobs_nums,
        forward_batch.token_ids_logprobs,
        sample_positions,
    )
def compute_logprobs_only(
    self,
    logits_output: LogitsProcessorOutput,
    forward_batch: ForwardBatch,
) -> None:
    """Compute token_ids_logprobs without sampling a next token.

    Intended for prefill-only requests that want ``token_ids_logprobs`` but no
    generated token. Nothing happens when the batch requests no
    ``token_ids_logprobs``.

    Args:
        logits_output: The logits output from the model forward.
        forward_batch: The forward batch that generated ``logits_output``.
    """
    if forward_batch.token_ids_logprobs:
        sampling_info = forward_batch.sampling_info

        # Same logits preprocessing as the sample method.
        self._preprocess_logits(logits_output, sampling_info)

        # The sampler does the logprob-only computation on logits_output.
        self.sampler.compute_logprobs_only(
            logits_output,
            sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )
def model_is_mrope(self) -> bool:
    """Tell whether the model's ``rope_scaling`` config contains ``"mrope_section"``.

    mrope requires keeping "rope_deltas" between the prompt and decoding
    phases. A missing or ``None`` ``rope_scaling`` counts as not mrope.
    """
    text_config = self.model_config.hf_text_config
    rope_scaling = getattr(text_config, "rope_scaling", {})
    return rope_scaling is not None and "mrope_section" in rope_scaling
def save_remote_model(self, url: str):
    """Save ``self.model`` to ``url`` through ``RemoteModelLoader.save_model``."""
    from sglang.srt.model_loader.loader import RemoteModelLoader

    logger.info(f"Saving model to {url}")
    source_path = self.model_config.model_path
    RemoteModelLoader.save_model(self.model, source_path, url)
def save_sharded_model(
    self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
):
    """Save ``self.model`` to ``path`` through ``ShardedStateLoader.save_model``.

    ``pattern`` and ``max_size`` are passed to the loader unchanged.
    """
    from sglang.srt.model_loader.loader import ShardedStateLoader

    summary = (
        f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
    )
    logger.info(summary)
    ShardedStateLoader.save_model(self.model, path, pattern, max_size)
def update_weights_from_ipc(self, recv_req):
    """Update weights from IPC for checkpoint-engine integration.

    Returns:
        A ``(success, message)`` tuple. Failures are reported through the
        tuple instead of being raised; only non-import failures are logged.
    """
    try:
        from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
            SGLangCheckpointEngineWorkerExtensionImpl,
        )

        # Worker extension that integrates with SGLang's model.
        extension = SGLangCheckpointEngineWorkerExtensionImpl(self)
        extension.update_weights_from_ipc(recv_req.zmq_handles)
    except ImportError as exc:
        return False, f"IPC weight update failed: ImportError {exc}"
    except Exception as exc:
        logger.error(f"IPC weight update failed: {exc}")
        return False, str(exc)
    else:
        return True, "IPC weight update completed successfully"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Load each ``(name, tensor)`` pair into the parameter of ``model`` with that name.

    Parameters are looked up in ``model.named_parameters()`` and filled through
    ``default_weight_loader``; a name with no matching parameter raises
    ``KeyError``.
    """
    param_by_name = dict(model.named_parameters())
    for param_name, weight in named_tensors:
        target_param = param_by_name[param_name]
        default_weight_loader(target_param, weight)
def _unwrap_tensor(tensor, tp_rank, device):
    """Return ``tensor`` moved to ``device``.

    A ``LocalSerializedTensor`` is first resolved to the entry for ``tp_rank``.
    """
    if isinstance(tensor, LocalSerializedTensor):
        resolved = tensor.get(tp_rank)
    else:
        resolved = tensor
    return resolved.to(device)
@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).

    The i-th element in the list corresponds to i-th rank's GPU.

    Declared as a dataclass so that the annotated ``values`` field gets a
    generated ``__init__``; without it the class could not be constructed with
    its payload and ``get`` would fail on the missing attribute.
    """

    # One serialized payload per rank, indexed by rank.
    values: List[bytes]

    def get(self, rank: int):
        """Deserialize and return the entry stored for ``rank``."""
        return MultiprocessingSerializer.deserialize(self.values[rank])
Xet Storage Details
- Size: 101 kB
- Xet hash: 3d15202132167d1d074caedaaecfd1fb5016a774d08f1b39d4d5b52330d58ede
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.