| """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" | |
| from __future__ import annotations | |
| import itertools | |
| import logging | |
| from typing import TYPE_CHECKING, Dict, List, Optional, Tuple | |
| import torch | |
| from torch.nn.parameter import Parameter, UninitializedParameter | |
| from sglang.srt.distributed import ( | |
| divide, | |
| get_tensor_model_parallel_rank, | |
| get_tensor_model_parallel_world_size, | |
| parallel_state, | |
| split_tensor_along_last_dim, | |
| tensor_model_parallel_all_gather, | |
| tensor_model_parallel_all_reduce, | |
| ) | |
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( | |
| use_symmetric_memory, | |
| ) | |
| from sglang.srt.layers.parameter import ( | |
| BasevLLMParameter, | |
| BlockQuantScaleParameter, | |
| PackedColumnParameter, | |
| PackedvLLMParameter, | |
| PerTensorScaleParameter, | |
| RowvLLMParameter, | |
| _ColumnvLLMParameter, | |
| ) | |
| from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod | |
| from sglang.srt.layers.utils import pad_or_narrow_weight | |
| from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs | |
| if TYPE_CHECKING: | |
| from sglang.srt.layers.quantization.base_config import ( | |
| QuantizationConfig, | |
| QuantizeMethodBase, | |
| ) | |
# Platform probes, evaluated once at import time.
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_npu = is_npu()

# Only consulted on HIP: when set, QKVParallelLinear drops its quant_config.
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
    "SGLANG_ROCM_DISABLE_LINEARQUANT"
)

logger = logging.getLogger(__name__)

# Class names of quantization methods for which ColumnParallelLinear hands
# `weight_loader_v2` (instead of `weight_loader`) to `create_weights`.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "AWQLinearAscendMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "BlockInt8LinearMethod",
    "MarlinLinearMethod",
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp4LinearMethod",
    "IPEXAWQLinearMethod",
    "PetitNvFp4LinearMethod",
]
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter being loaded; may carry a ``marlin_tile_size``
            attribute.
        shard_size: Shard size before the adjustment.
        shard_offset: Shard offset before the adjustment.

    Returns:
        ``(shard_size, shard_offset)``. Both values are multiplied by
        ``marlin_tile_size`` when that attribute exists and is not ``None``;
        otherwise they are returned unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
) -> Tuple[int, int]:
    """Rescale one shard's offset and size to the BitsAndBytes-quantized layout.

    Args:
        param: Quantized parameter; ``param.data.shape[0]`` is taken as the
            quantized total length.
        shard_offsets: Maps each shard id to ``(offset, size)`` in the
            unquantized layout. The ``"total"`` entry holds the unquantized
            total length as its first element.
        loaded_shard_id: Key of the shard to convert.

    Returns:
        ``(quantized_size, quantized_offset)`` - note the size comes first.
    """
    unquantized_total, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    quantized_total = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Multiply before dividing so the integer ratio is not truncated early.
        return value * quantized_total // unquantized_total

    quantized_offset = _rescale(offset)
    quantized_size = _rescale(size)
    return quantized_size, quantized_offset
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused scale array that one shard's scale fills.

    For fused modules (QKV and MLP) the parameter is an array of length N
    holding one scale per "logical" matrix, while ``loaded_weight`` is the
    scale of a single shard on disk.

    Args:
        param: The fused scale array (indexable by integer).
        loaded_weight: The scale read from the checkpoint, either shapeless
            or with a leading dimension of exactly 1.
        shard_id: Integer index, or one of ``"q"``, ``"k"``, ``"v"`` which
            map to 0, 1 and 2.

    Returns:
        ``(param[index], loaded_weight)`` with the leading dimension of
        ``loaded_weight`` removed when it had one.

    Raises:
        ValueError: If ``shard_id`` is neither a string nor an integer.
    """
    if isinstance(shard_id, str):
        slot = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        slot = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    scale = loaded_weight
    # AutoFP8 scales do not have a shape; compressed-tensors scales do.
    if len(scale.shape) > 0:
        assert scale.shape[0] == 1
        scale = scale[0]

    return param[slot], scale
def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
    """Rescale fused shard offsets to the real size of the loaded weight.

    Args:
        shard_offsets: List of ``(shard_id, shard_offset, shard_size)``
            entries. The last entry's offset plus size is taken as the
            expected total length.
        loaded_weight: Weight whose ``size(dim)`` is the actual length.
        dim: Dimension along which the shards are laid out.

    Returns:
        ``shard_offsets`` itself when the actual length equals the expected
        one. Otherwise a new list in which each size is scaled by
        ``actual / expected`` (integer arithmetic) and offsets are the
        running sum of the scaled sizes.
    """
    actual_total = loaded_weight.size(dim)
    tail = shard_offsets[-1]
    expected_total = tail[-1] + tail[-2]
    if actual_total == expected_total:
        return shard_offsets

    rescaled = []
    cursor = 0
    # The original offsets are ignored: new offsets are rebuilt from sizes.
    for shard_id, _stale_offset, size in shard_offsets:
        scaled_size = actual_total * size // expected_total
        rescaled.append((shard_id, cursor, scaled_size))
        cursor += scaled_size
    return rescaled
class LinearBase(torch.nn.Module):
    """Common base for the linear layers defined in this module.

    Records the layer dimensions and resolves the quantization method.
    ``forward`` is left unimplemented here.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: stored as ``self.skip_bias_add`` for subclasses.
        params_dtype: Data type for the parameters. ``None`` selects
            ``torch.get_default_dtype()``.
        quant_config: Quantization configure. ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: passed to ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_method: Optional[QuantizeMethodBase] = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Not implemented on the base class."""
        raise NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is created on this instance (no sharding).

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if not bias:
            self.register_parameter("bias", None)
            return

        self.bias = Parameter(torch.empty(self.output_size, dtype=self.params_dtype))
        bias_attrs = {"output_dim": 0, "weight_loader": self.weight_loader}
        set_weight_attrs(self.bias, bias_attrs)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` after shape normalization.

        Raises:
            ValueError: On NPU, when a multi-element tensor has to be reduced
                to a single-element parameter but its values are not all
                close to the first one.
        """
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The per-tensor quant-scale must be 1 dimension
        shrink_to_one = (
            _is_npu
            and param.size() != loaded_weight.size()
            and param.size(0) == 1
        )
        if shrink_to_one:
            if not torch.allclose(loaded_weight, loaded_weight[0]):
                raise ValueError(f"{loaded_weight} are not all equal")
            loaded_weight = loaded_weight[:1]

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the layer.

        Returns:
            ``(output, output_bias)``. With ``skip_bias_add`` the bias is not
            passed to the quant method and is returned as ``output_bias``;
            otherwise it is passed along and ``output_bias`` is ``None``.
        """
        if self.skip_bias_add:
            applied_bias, output_bias = None, self.bias
        else:
            applied_bias, output_bias = self.bias, None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, applied_bias)
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the summary fragment used in the module's repr."""
        return (
            f"in_features={self.input_size}"
            f", output_features={self.output_size}"
            f", bias={self.bias is not None}"
        )
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This constructor does not read
                       the argument; per-shard partition sizes are taken from a
                       ``self.output_sizes`` attribute when a subclass has set
                       one before calling this constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        tp_rank: Tensor-parallel rank; defaults to
                 ``get_tensor_model_parallel_rank()``.
        tp_size: Tensor-parallel world size; defaults to
                 ``get_tensor_model_parallel_world_size()``.
        use_presharded_weights: If true, ``weight_loader`` does not narrow the
                       checkpoint tensor to this rank's slice.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[List[int]] = None,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
        use_presharded_weights: bool = False,
    ):
        super().__init__(
            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
        )

        self.gather_output = gather_output
        self.use_presharded_weights = use_presharded_weights

        # Divide the weight matrix along the last dimension.
        if tp_rank is None:
            tp_rank = get_tensor_model_parallel_rank()
        if tp_size is None:
            tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank, self.tp_size = tp_rank, tp_size

        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2 loader.
        uses_v2_loader = (
            self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
        )
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if uses_v2_loader else self.weight_loader
            ),
        )

        if bias:
            # Use the resolved dtype, consistent with the weights created above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim`` and is not flagged
        ``use_bitsandbytes_4bit``, the checkpoint tensor is narrowed along
        ``output_dim`` to ``[tp_rank * shard, (tp_rank + 1) * shard)`` unless
        ``use_presharded_weights`` is set. On CPU the narrowing is delegated
        to ``narrow_padded_param_and_loaded_weight``.
        """
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        param_data = param.data
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if output_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size

            if _is_cpu:
                from sglang.srt.model_loader.weight_utils import (
                    narrow_padded_param_and_loaded_weight,
                )

                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                    param_data,
                    loaded_weight,
                    0,  # param_data_start
                    start_idx,
                    output_dim,
                    shard_size,
                    not self.use_presharded_weights,
                )
            elif not self.use_presharded_weights:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's own loading method."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        if isinstance(param, _ColumnvLLMParameter):
            param.load_column_parallel_weight(
                loaded_weight,
                tp_rank=self.tp_rank,
                use_presharded_weights=self.use_presharded_weights,
            )
        else:
            # FIXME: This branch is needed to load deepseek v3 awq.
            # However, we should fix this and avoid the branching here.
            param.load_column_parallel_weight(loaded_weight)

    def forward(self, input_):
        """Apply the layer to ``input_``.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across the
            tensor-parallel group when ``gather_output`` is set. With
            ``skip_bias_add`` the bias is not applied and is returned as
            ``output_bias``; otherwise ``output_bias`` is ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the summary fragment used in the module's repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        tp_rank: Tensor-parallel rank; defaults to
                 ``get_tensor_model_parallel_rank()``.
        tp_size: Tensor-parallel world size; defaults to
                 ``get_tensor_model_parallel_world_size()``.
        use_presharded_weights: If true, the loaders do not slice checkpoint
                       tensors by rank.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
        use_presharded_weights: bool = False,
    ):
        # Set before the parent constructor runs: it looks for
        # `self.output_sizes` to derive the per-shard partition sizes.
        self.output_sizes = output_sizes
        if tp_rank is None:
            tp_rank = get_tensor_model_parallel_rank()
        if tp_size is None:
            tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank, self.tp_size = tp_rank, tp_size
        assert not any(size % tp_size for size in output_sizes)
        self.use_presharded_weights = use_presharded_weights
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            tp_rank=tp_rank,
            tp_size=tp_size,
            use_presharded_weights=use_presharded_weights,
        )
        self.prefix = prefix

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        """Load one shard, or a tensor already fused on disk, into ``param``.

        With ``loaded_shard_id=None`` and an ``output_dim`` on the parameter,
        the tensor is split into one slice per entry of ``self.output_sizes``
        and this method is called again for each slice with its shard id.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        if getattr(param, "is_gguf_weight_type", False):
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if getattr(param, "is_gguf_weight", False):
            gguf_dim = getattr(param, "output_dim", None)
            rank_len = loaded_weight.size(gguf_dim) // self.tp_size
            rank_start = self.tp_rank * rank_len
            loaded_weight = loaded_weight.narrow(gguf_dim, rank_start, rank_len)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # (shard_id, shard_offset, shard_size) in the unsharded layout.
            shard_offsets: List[Tuple[int, int, int]] = list(
                zip(
                    itertools.count(),
                    itertools.accumulate(self.output_sizes, initial=0),
                    self.output_sizes,
                )
            )
            packed_dim = getattr(param, "packed_dim", None)
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            if _is_cpu:
                shard_offsets = adjust_shard_offsets(
                    shard_offsets, loaded_weight, output_dim
                )

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size //= param.pack_factor
                    shard_offset //= param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    starts = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(pos): (starts[pos], width)
                        for pos, width in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size //= param.pack_factor
                shard_offset //= param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size

            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            slice_by_rank = not (use_bitsandbytes_4bit or self.use_presharded_weights)

            if _is_cpu:
                from sglang.srt.model_loader.weight_utils import (
                    narrow_padded_param_and_loaded_weight,
                )

                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                    param_data,
                    loaded_weight,
                    0,  # param_data_start
                    start_idx,
                    output_dim,
                    shard_size,
                    slice_by_rank,
                )
            elif slice_by_rank:
                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
                end_idx = start_idx + shard_size
                if end_idx > loaded_weight.shape[output_dim]:
                    loaded_weight = pad_or_narrow_weight(
                        loaded_weight, output_dim, start_idx, shard_size
                    )
                else:
                    loaded_weight = loaded_weight.narrow(
                        output_dim, start_idx, shard_size
                    )

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "MergedColumnParallelLinear, assume the weight is "
                "the same for all partitions."
            )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        running_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for index, width in enumerate(self.output_sizes):
            shard_offsets.append((index, running_offset, width))
            running_offset += width

        packed_types = (PackedColumnParameter, PackedvLLMParameter)
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            is_packed_on_output = (
                isinstance(param, packed_types)
                and param.packed_dim == param.output_dim
            )
            if is_packed_on_output:
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        """Load a shard (or a fused tensor) through the parameter's own loader.

        With ``loaded_shard_id=None`` the dispatch depends on the parameter
        type; tensors that need splitting go through
        ``_load_fused_module_from_checkpoint``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(
                    loaded_weight=loaded_weight,
                    shard_id=0,
                    tp_rank=self.tp_rank,
                    tp_size=self.tp_size,
                )
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(
                    loaded_weight=loaded_weight,
                    tp_rank=self.tp_rank,
                    tp_size=self.tp_size,
                )
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = self.quant_method.quant_config.weight_block_size
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block-quantized scales are indexed in blocks of `block_n`, so the
            # element counts are ceil-divided before splitting across ranks.
            blocks_before = (
                sum(self.output_sizes[:loaded_shard_id]) + block_n - 1
            ) // block_n
            blocks_in_shard = (
                self.output_sizes[loaded_shard_id] + block_n - 1
            ) // block_n
            shard_offset = blocks_before // self.tp_size
            shard_size = blocks_in_shard // self.tp_size
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            use_presharded_weights=self.use_presharded_weights,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
        )
| class QKVParallelLinear(ColumnParallelLinear): | |
| """Linear layers for the attention's QKV transformation. | |
| Linear layers for the linear transformation of the query, key, and value | |
| vectors in the attention layer. The weight matrix is concatenated along | |
| the output dimension. The layer is parallelized along the head dimension. | |
| When the number of key/value heads is smaller than the number of query | |
| heads (e.g., multi-query/grouped-query attention), the key/value head may | |
| be replicated while the query heads are partitioned. | |
| Args: | |
| hidden_size: input hidden state size of the transformer. | |
| head_size: size of each attention head. | |
| total_num_heads: total number of attention query heads. | |
| total_num_kv_heads: total number of attention key/value heads. If | |
| None, assume total_num_kv_heads = total_num_heads. | |
| bias: If true, add bias. | |
| skip_bias_add: This was added to enable performance optimizations where | |
| bias can be fused with other element-wise operations. we | |
| skip adding bias but instead return it. | |
| params_dtype: Data type for the parameters. | |
| quant_config: Quantization configure. | |
| prefix: The name of the layer in the state dict, including all parents | |
| (e.g. model.layers.0.qkv_proj) | |
| """ | |
def __init__(
    self,
    hidden_size: int,
    head_size: int,
    total_num_heads: int,
    total_num_kv_heads: Optional[int] = None,
    bias: bool = True,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    tp_rank: Optional[int] = None,
    tp_size: Optional[int] = None,
    load_presharded_attn: bool = False,
):
    """Derive per-rank head counts and shard sizes, then build the layer.

    ``total_num_kv_heads`` defaults to ``total_num_heads``. When ``tp_size``
    is at least the number of KV heads, each rank keeps one KV head and
    ``num_kv_head_replicas`` records how many ranks share it. The
    quantization config is discarded when ``_disable_hip_linear_quant``
    is set.
    """
    self.hidden_size = hidden_size
    self.head_size = head_size
    self.total_num_heads = total_num_heads
    self.total_num_kv_heads = (
        total_num_heads if total_num_kv_heads is None else total_num_kv_heads
    )

    # Divide the weight matrix along the last dimension.
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()
    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank, self.tp_size = tp_rank, tp_size

    self.num_heads = divide(self.total_num_heads, tp_size)
    if tp_size >= self.total_num_kv_heads:
        # More ranks than KV heads: replicate one KV head per rank.
        self.num_kv_heads = 1
        self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
    else:
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        self.num_kv_head_replicas = 1

    q_width = self.num_heads * self.head_size
    kv_width = self.num_kv_heads * self.head_size
    self.q_proj_shard_size = q_width
    self.kv_proj_shard_size = kv_width

    self.output_sizes = [
        q_width * tp_size,  # q_proj
        kv_width * tp_size,  # k_proj
        kv_width * tp_size,  # v_proj
    ]
    self.use_presharded_weights = load_presharded_attn

    if _disable_hip_linear_quant:
        quant_config = None

    super().__init__(
        input_size=self.hidden_size,
        output_size=(
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        ),
        bias=bias,
        gather_output=False,
        skip_bias_add=skip_bias_add,
        params_dtype=params_dtype,
        quant_config=quant_config,
        prefix=prefix,
        tp_rank=tp_rank,
        tp_size=tp_size,
        use_presharded_weights=self.use_presharded_weights,
    )
def _get_shard_offset_mapping(self, loaded_shard_id: str):
    """Return the per-rank offset of a shard inside the fused QKV output.

    Recognized ids are ``"q"``, ``"k"``, ``"v"`` and ``"total"`` (the full
    per-rank width). Any other id yields ``None``.
    """
    q_width = self.num_heads * self.head_size
    kv_width = self.num_kv_heads * self.head_size
    offsets = {
        "q": 0,
        "k": q_width,
        "v": q_width + kv_width,
        "total": q_width + 2 * kv_width,
    }
    return offsets.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
    """Return the per-rank width of the ``"q"``, ``"k"`` or ``"v"`` shard.

    Any other id yields ``None``.
    """
    kv_width = self.num_kv_heads * self.head_size
    sizes = {
        "q": self.num_heads * self.head_size,
        "k": kv_width,
        "v": kv_width,
    }
    return sizes.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(
    self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
    """
    Handle special case for models where QKV layers are already
    fused on disk. In this case, we have no shard id. This function
    determines the shard id by splitting these layers and then calls
    the weight loader using the shard id.

    An example of a model with these fused layers:
    https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

    Raises:
        NotImplementedError: If ``use_presharded_weights`` is set. The
            slicing below uses the unsharded q/k/v widths, and no layout is
            defined here for a fused tensor that is already split by rank.
            Previously this combination failed with ``UnboundLocalError``
            because the per-shard slice was never assigned.
    """
    if self.use_presharded_weights:
        raise NotImplementedError(
            "Loading a fused QKV checkpoint tensor is not supported together "
            "with pre-sharded attention weights."
        )

    q_width = self.total_num_heads * self.head_size
    kv_width = self.total_num_kv_heads * self.head_size
    shard_offsets = [
        # (shard_id, shard_offset, shard_size)
        ("q", 0, q_width),
        ("k", q_width, kv_width),
        ("v", q_width + kv_width, kv_width),
    ]

    for shard_id, shard_offset, shard_size in shard_offsets:
        # Special case for Quantization.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        if (
            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim
        ):
            shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                shard_size=shard_size, shard_offset=shard_offset
            )

        loaded_weight_shard = loaded_weight.narrow(
            param.output_dim, shard_offset, shard_size
        )
        self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def _load_qkv_block_scale(
    self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
    """Split a fused block-quantization scale tensor into q/k/v and load each.

    All sizes are expressed in blocks: element counts are floor-divided by the
    first entry of ``quant_config.weight_block_size``. The source slices are
    taken with the ``total_*`` head counts, while the destination offset and
    size passed to ``param.load_qkv_weight`` come from
    ``_get_shard_offset_mapping`` / ``_get_shard_size_mapping``.

    Args:
        param: Destination scale parameter; ``param.output_dim`` is the axis
            that is sliced.
        loaded_weight: Fused scale tensor with q, k and v back to back.
    """
    block_n, _ = self.quant_method.quant_config.weight_block_size

    q_blocks = self.total_num_heads * self.head_size // block_n
    kv_blocks = self.total_num_kv_heads * self.head_size // block_n

    # (shard_id, first block in the fused tensor, number of blocks)
    source_layout = (
        ("q", 0, q_blocks),
        ("k", q_blocks, kv_blocks),
        ("v", q_blocks + kv_blocks, kv_blocks),
    )

    for shard_id, src_start, src_len in source_layout:
        source_slice = loaded_weight.narrow(param.output_dim, src_start, src_len)
        dest_offset = self._get_shard_offset_mapping(shard_id) // block_n
        dest_size = self._get_shard_size_mapping(shard_id) // block_n
        param.load_qkv_weight(
            loaded_weight=source_slice,
            num_heads=self.num_kv_head_replicas,
            shard_id=shard_id,
            shard_offset=dest_offset,
            shard_size=dest_size,
            tp_rank=self.tp_rank,
            use_presharded_weights=self.use_presharded_weights,
        )
def weight_loader_v2(
    self,
    param: BasevLLMParameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[str] = None,
):
    """Load a checkpoint tensor into a fused QKV parameter.

    Args:
        param: Destination parameter.
        loaded_weight: Tensor read from the checkpoint.
        loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` when the checkpoint stores
            the projections separately; ``None`` when no shard id is known, in
            which case the call is dispatched on the kind of ``param``.
    """
    if loaded_shard_id is None:  # special case for certain models
        if isinstance(param, PerTensorScaleParameter):
            param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
        elif type(param) in (RowvLLMParameter, BasevLLMParameter):
            # Exact type comparison (not isinstance): subclasses of these two
            # classes fall through to the branches below.
            param.load_qkv_weight(loaded_weight=loaded_weight)
        elif isinstance(param, BlockQuantScaleParameter):
            self._load_qkv_block_scale(param, loaded_weight)
        else:
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
        return

    assert loaded_shard_id in ["q", "k", "v"]

    dest_offset = self._get_shard_offset_mapping(loaded_shard_id)
    dest_size = self._get_shard_size_mapping(loaded_shard_id)

    if isinstance(param, BlockQuantScaleParameter):
        # Scales are stored one per block, so convert element counts to
        # block counts, rounding up.
        block_shape = self.quant_method.quant_config.weight_block_size
        block_n, _ = block_shape[0], block_shape[1]
        dest_offset = (dest_offset + block_n - 1) // block_n
        dest_size = (dest_size + block_n - 1) // block_n

    param.load_qkv_weight(
        loaded_weight=loaded_weight,
        num_heads=self.num_kv_head_replicas,
        shard_id=loaded_shard_id,
        shard_offset=dest_offset,
        shard_size=dest_size,
        tp_rank=self.tp_rank,
        use_presharded_weights=self.use_presharded_weights,
    )
def weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[str] = None,
):
    """Load a checkpoint tensor into a fused QKV parameter (legacy loader).

    The behaviour is driven by optional attributes read from ``param`` with
    ``getattr``: ``is_gguf_weight``, ``is_gguf_weight_type``, ``output_dim``,
    ``is_metadata``, ``needs_scalar_to_array``, ``packed_dim`` /
    ``pack_factor``, ``use_bitsandbytes_4bit`` and ``ignore_warning``.

    Args:
        param: Destination parameter.
        loaded_weight: Tensor read from the checkpoint.
        loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` when the checkpoint stores
            the projections separately; ``None`` when q, k and v are already
            fused on disk. In the fused case the tensor is split and this
            method is called again once per section.

    Raises:
        NotImplementedError: If a fused tensor with an ``output_dim`` is
            loaded while ``self.use_presharded_weights`` is set. The split
            boundaries are derived from the ``total_*`` head counts and no
            slicing rule is defined for pre-sharded fused tensors, so the
            combination is rejected explicitly instead of failing later on an
            unassigned local variable.
    """
    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type and loaded_shard_id is not None:
        idx_map = {"q": 0, "k": 1, "v": 2}
        param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
        param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        return

    if is_gguf_weight:
        # GGUF weights are not copied into param.data here; the slice for
        # this rank is appended to the parameter's data container.
        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // self.tp_size
        start_idx = self.tp_rank * shard_size

        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        param.shard_id.append(loaded_shard_id)
        param.shard_id_map[loaded_shard_id] = len(param.data_container)
        param.data_container.append(loaded_weight)
        return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv/mlp).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0
                )

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        if self.use_presharded_weights:
            raise NotImplementedError(
                "Loading a fused QKV checkpoint tensor is not supported when "
                "use_presharded_weights is enabled."
            )

        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
        ]
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        packed_dim = getattr(param, "packed_dim", None)
        if _is_cpu:
            shard_offsets = adjust_shard_offsets(
                shard_offsets, loaded_weight, output_dim
            )

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (
                        self.total_num_heads * self.head_size,
                        self.total_num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.total_num_heads + self.total_num_kv_heads)
                        * self.head_size,
                        self.total_num_kv_heads * self.head_size,
                    ),
                    "total": (
                        (self.total_num_heads + 2 * self.total_num_kv_heads)
                        * self.head_size,
                        0,
                    ),
                }

                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, shard_id
                )

            # Cut this section out of the fused tensor and load it through
            # the per-shard path below.
            loaded_weight_shard = loaded_weight.narrow(
                output_dim, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset
            )

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if use_bitsandbytes_4bit:
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (
                    self.num_heads * self.head_size,
                    self.num_kv_heads * self.head_size,
                ),
                "v": (
                    (self.num_heads + self.num_kv_heads) * self.head_size,
                    self.num_kv_heads * self.head_size,
                ),
                "total": (
                    (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                    0,
                ),
            }
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, orig_qkv_offsets, loaded_shard_id
            )

        # Destination: this shard's section inside the fused parameter.
        param_data = param_data.narrow(output_dim, shard_offset, shard_size)
        # Source: q is indexed by tp_rank directly; for k/v the rank is
        # divided by num_kv_head_replicas, so that many consecutive ranks
        # read the same k/v slice.
        if loaded_shard_id == "q":
            shard_id = self.tp_rank
        else:
            shard_id = self.tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size

        if _is_cpu:
            from sglang.srt.model_loader.weight_utils import (
                narrow_padded_param_and_loaded_weight,
            )

            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                param_data,
                loaded_weight,
                0,  # param_data_start
                start_idx,
                output_dim,
                shard_size,
                not use_bitsandbytes_4bit and not self.use_presharded_weights,
            )
        else:
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                loaded_weight = loaded_weight.narrow(
                    output_dim, start_idx, shard_size
                )

    # Special case for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id
        )
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions."
            )

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and tp_size > 1), ``forward`` all-reduces the
                        partial outputs across the tensor-parallel group.
        quant_config: Quantization configure.
        prefix: Forwarded unchanged to ``LinearBase.__init__``.
        tp_rank: Tensor-parallel rank of this layer; defaults to
                 ``get_tensor_model_parallel_rank()``.
        tp_size: Tensor-parallel world size; defaults to
                 ``get_tensor_model_parallel_world_size()``.
        use_presharded_weights: If true, the weight loaders do not slice the
                                checkpoint tensor along the input dimension.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
        use_presharded_weights: bool = False,
    ):
        # Quantization can be switched off for linear layers on ROCm via
        # the SGLANG_ROCM_DISABLE_LINEARQUANT environment variable.
        quant_config = None if _disable_hip_linear_quant else quant_config
        super().__init__(
            input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        if tp_rank is None:
            tp_rank = get_tensor_model_parallel_rank()
        if tp_size is None:
            tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank, self.tp_size = tp_rank, tp_size
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.use_presharded_weights = use_presharded_weights

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get
            # the parameter-class based loader; all others the legacy one.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param``, slicing this rank's part.

        When ``param`` has an ``input_dim`` attribute (and neither
        bitsandbytes 4-bit nor pre-sharded weights are in use), the slice
        ``[tp_rank * shard, (tp_rank + 1) * shard)`` along that dimension is
        taken from ``loaded_weight``; otherwise the tensor is copied whole.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: input_dim == 0 is a valid
            # dimension and must still be divided across ranks, matching the
            # `input_dim is not None` test used for the narrowing below.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if (
            input_dim is not None
            and not use_bitsandbytes_4bit
            and not self.use_presharded_weights
        ):
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size

            if _is_cpu:
                from sglang.srt.model_loader.weight_utils import (
                    narrow_padded_param_and_loaded_weight,
                )

                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                    param_data,
                    loaded_weight,
                    0,  # param_data_start
                    start_idx,
                    input_dim,
                    shard_size,
                )
            else:
                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
                end_idx = start_idx + shard_size
                if end_idx > loaded_weight.shape[input_dim]:
                    loaded_weight = pad_or_narrow_weight(
                        loaded_weight, input_dim, start_idx, shard_size
                    )
                else:
                    loaded_weight = loaded_weight.narrow(
                        input_dim, start_idx, shard_size
                    )

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor by delegating to the parameter object."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        if isinstance(param, RowvLLMParameter):
            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
            # It supports additional parameters like tp_rank and use_presharded_weights.
            param.load_row_parallel_weight(
                loaded_weight,
                tp_rank=self.tp_rank,
                use_presharded_weights=self.use_presharded_weights,
            )
        else:
            # `params` is defined in `vllm/model_executor/parameter.py`,
            # It does not support additional parameters.
            param.load_row_parallel_weight(loaded_weight)

    def forward(self, input_, skip_all_reduce=False):
        """Apply the layer to ``input_``.

        Args:
            input_: Input tensor. Split along its last dimension here unless
                ``input_is_parallel`` is set.
            skip_all_reduce: If true, return this rank's partial result even
                when ``reduce_results`` is set.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
            sm.tag(output_parallel)

        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary string built from its size and TP settings."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
Xet Storage Details
- Size:
- 56.1 kB
- Xet hash:
- 96b6eb6a3f8ff09358bf2cfe3ff7777441a078ef63c7c7e3f03ae5811b640eb5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.