| # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py | |
| import logging | |
| from typing import Any, Dict, List, Optional | |
| import regex as re | |
| import torch | |
| from torch.nn.parameter import Parameter | |
| from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod | |
| from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter | |
| from sglang.srt.layers.quantization.base_config import ( | |
| LinearMethodBase, | |
| QuantizationConfig, | |
| QuantizeMethodBase, | |
| ) | |
| from sglang.srt.layers.quantization.petit_utils import ( | |
| apply_petit_nvfp4_linear, | |
| prepare_nvfp4_layer_for_petit, | |
| verify_petit_nvfp4_supported, | |
| ) | |
| from sglang.srt.layers.quantization.utils import is_layer_skipped | |
| from sglang.srt.utils import is_hip | |
# Evaluated once at import time. `is_hip` comes from sglang.srt.utils and
# presumably reports whether this is a ROCm/HIP build -- confirm there.
_is_hip = is_hip()

# Initialize logger for the module
logger = logging.getLogger(__name__)
# Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4.

    Holds the settings read from the ``quantization`` section of a ModelOpt
    ``hf_quant_config.json`` and decides, per layer, whether the Petit NVFP4
    linear method or the unquantized linear method is used.

    Attributes:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint already stores
            NVFP4 weights (``"NVFP4"`` appears in ``quant_algo``).
        group_size: Number of input elements that share one weight scale.
        kv_cache_quant_algo: KV-cache quantization algorithm name; ``"auto"``
            when the checkpoint config does not name one.
        exclude_modules: Layer-name patterns that must stay unquantized.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
    ) -> None:
        # Give the base class a chance to set up its own state.
        # NOTE(review): assumes QuantizationConfig.__init__ takes no
        # arguments -- confirm against base_config.py.
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required."""
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the checkpoint file names that carry the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
        """Build a config from the parsed ``hf_quant_config.json`` contents.

        Args:
            config: Parsed JSON mapping; its ``"quantization"`` entry must
                provide ``quant_algo``, ``group_size`` and ``exclude_modules``.
                ``kv_cache_quant_algo`` is optional and defaults to ``"auto"``.

        Returns:
            A new ``PetitNvFp4Config``.

        Raises:
            KeyError: If ``quant_algo`` is missing from the quantization
                section.
            ValueError: If ``group_size`` is missing/zero or
                ``exclude_modules`` is missing. ``verify_petit_nvfp4_supported``
                may raise as well (its behavior is defined elsewhere).
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        group_size = quant_config.get("group_size", None)
        verify_petit_nvfp4_supported(quant_method, group_size)

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        # A missing key is treated like an explicit null: both mean "auto".
        kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") or "auto"
        exclude_modules = quant_config.get("exclude_modules", None)

        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
            logger.warning(
                f"group_size: {group_size},"
                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                f"exclude_modules: {exclude_modules}"
            )
            raise ValueError(
                "NVFP4 quantization requires group size, "
                "kv_cache_quant_algo and exclude_modules specified in "
                "hf_quant_config.json"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """Return this method's name when it can take over the checkpoint.

        Args:
            hf_quant_cfg: Quantization config mapping from the checkpoint.
            user_quant: Quantization method requested by the user (unused).

        Returns:
            ``get_name()`` if the checkpoint is compatible, otherwise ``None``.
        """
        if cls.is_petit_nvfp4_compatible(hf_quant_cfg):
            return cls.get_name()
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        """Tell whether ``quant_config`` names the ``modelopt`` method on HIP.

        The comparison is case-insensitive; a missing or null ``quant_method``
        is treated as incompatible.
        """
        quant_method = str(quant_config.get("quant_method") or "").lower()
        return _is_hip and quant_method == "modelopt"

    def is_layer_excluded(self, prefix: str, exclude_modules: list) -> bool:
        """Check ``prefix`` against wildcard patterns.

        Each pattern is turned into a regular expression by escaping ``.``
        and expanding ``*`` to ``.*``; the whole prefix must match.

        Args:
            prefix: Fully qualified layer name.
            exclude_modules: Patterns to test; ``None`` is treated as empty.

        Returns:
            ``True`` if any pattern matches the entire prefix.
        """
        return any(
            re.fullmatch(pattern.replace(".", r"\.").replace("*", r".*"), prefix)
            for pattern in exclude_modules or ()
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: Module being constructed.
            prefix: Fully qualified name of the module.

        Returns:
            ``UnquantizedLinearMethod`` for excluded linear layers,
            ``PetitNvFp4LinearMethod`` for other linear layers, and ``None``
            for anything that is not a ``LinearBase``.
        """
        if not isinstance(layer, LinearBase):
            return None

        # A config built directly (not via from_config) may carry None here.
        excluded = self.exclude_modules or []
        if is_layer_skipped(prefix, excluded) or self.is_layer_excluded(
            prefix, excluded
        ):
            return UnquantizedLinearMethod()
        return PetitNvFp4LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations; none for this method."""
        return []
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4 checkpoints executed through Petit.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the empty NVFP4 parameters on ``layer``.

        Registers, in order, ``weight``, ``input_scale``, ``weight_scale_2``
        and ``weight_scale``, and records the partition sizes on the layer.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output widths of the logical sub-layers.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Only consulted for the scale dtype when the
                checkpoint is not NVFP4-serialized (a case rejected up front).
            **extra_weight_attrs: May carry a ``weight_loader`` callable.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size

        serialized = self.quant_config.is_checkpoint_nvfp4_serialized
        if not serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        total_output_size = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        # Size bookkeeping is stored before the divisibility check below.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size

        if input_size_per_partition % 16:
            raise ValueError(
                "Unsupported model when in features size is " "not multiple of 16"
            )

        if serialized:
            scale_dtype = torch.float8_e4m3fn
        else:
            scale_dtype = params_dtype

        # 2 fp4 data is packed in one uint8 in the input dimension
        packed_weight = ModelWeightParameter(
            data=torch.empty(
                total_output_size,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", packed_weight)

        # One float32 scale slot per logical sub-layer for both global scales.
        num_partitions = len(output_partition_sizes)
        for scale_name in ("input_scale", "weight_scale_2"):
            per_tensor_scale = PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=loader,
            )
            layer.register_parameter(scale_name, per_tensor_scale)

        # One scale per group of `group_size` input elements.
        block_scales = ModelWeightParameter(
            data=torch.empty(
                total_output_size,
                input_size_per_partition // self.quant_config.group_size,
                dtype=scale_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight_scale", block_scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse the loaded scales and hand the layer to Petit.

        Replaces ``input_scale`` and ``weight_scale_2`` with their maxima,
        stores their product as ``alpha``, calls
        ``prepare_nvfp4_layer_for_petit`` and finally removes
        ``input_scale`` from the layer.
        """
        # Reduce both scale vectors before either attribute is overwritten.
        global_input_scale = layer.input_scale.max().to(torch.float32)
        global_weight_scale = layer.weight_scale_2.max().to(torch.float32)

        layer.input_scale = Parameter(global_input_scale, requires_grad=False)
        layer.weight_scale_2 = Parameter(global_weight_scale, requires_grad=False)

        alpha = layer.input_scale * layer.weight_scale_2
        layer.alpha = Parameter(alpha, requires_grad=False)

        prepare_nvfp4_layer_for_petit(layer)

        # `apply` never reads the input scale, so it is dropped here.
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Petit NVFP4 linear kernel for ``layer`` on ``x``.

        Args:
            layer: Module prepared by ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The tensor returned by ``apply_petit_nvfp4_linear``.
        """
        kernel_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return apply_petit_nvfp4_linear(**kernel_kwargs)
Xet Storage Details
- Size:
- 8.94 kB
- Xet hash:
- cebbe8f418f3c0632265ea0001e5d1d93ed4128849e6c2f1d5205bf1cc427329
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.