# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy
from typing import List, Optional

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import (
    FP8_E4M3_DATA,
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from pydantic import BaseModel, ConfigDict, model_validator

__all__ = [
    "QuantizationScheme",
    "preset_name_to_scheme",
    "is_preset_scheme",
]


class QuantizationScheme(BaseModel):
    """
    Set of QuantizationArgs defining how the weights, inputs and outputs of target
    list of modules should be quantized

    :param targets: list of modules to apply the QuantizationArgs to, can be layer
        names, layer types or a regular expression, typically ["Linear"]
    :param weights: quantization config for layer weights
    :param input_activations: quantization config for layer inputs
    :param output_activations: quantization config for layer outputs
    :param format: CompressionFormat for the layer
    """

    targets: List[str]
    weights: Optional[QuantizationArgs] = None
    input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None
    format: Optional[str] = None

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
        """
        Cross-field validation run after individual field validation.

        Enforces:
        - only a whitelisted set of strategies for input activations
        - group-wise input activation quantization must be fully dynamic
        - no activation ordering (actorder) on input or output activations
        - ``mixed-precision`` is a config-level format, not a per-scheme one
        - warns when weights and inputs both use GROUP with mismatched sizes

        :raises NotImplementedError: for unsupported activation strategies
        :raises ValueError: for actorder on activations or a mixed-precision format
        :return: the validated model, unchanged
        """
        inputs = model.input_activations
        outputs = model.output_activations
        weights = model.weights
        format_ = model.format  # local alias; avoid shadowing builtin `format`

        if inputs is not None:
            if inputs.strategy not in (
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.TENSOR,
                QuantizationStrategy.GROUP,
                QuantizationStrategy.TENSOR_GROUP,
                QuantizationStrategy.ATTN_HEAD,
            ):
                raise NotImplementedError(
                    f"Using {inputs.strategy} strategy is not supported for "
                    "activation quantization"
                )

            # FIX: this check was previously nested inside the strategy
            # whitelist branch above, where GROUP (a whitelist member) could
            # never reach it, and the condition read `dynamic is True` —
            # contradicting the error message. Group-wise activation
            # quantization is only supported when fully dynamic; `dynamic`
            # may be True, False, or DynamicType.LOCAL, so reject everything
            # that is not exactly True.
            if (
                inputs.strategy == QuantizationStrategy.GROUP
                and inputs.dynamic is not True
            ):
                raise NotImplementedError(
                    "Static and local group-wise activation "
                    "quantization is not supported"
                )

            if inputs.actorder is not None:
                raise ValueError("Cannot apply actorder to input activations")

        if outputs is not None:
            if outputs.actorder is not None:
                raise ValueError("Cannot apply actorder to output activations")

        # mixed-precision is only meaningful at the model/config level
        if format_ == CompressionFormat.mixed_precision.value:
            raise ValueError(
                "mixed-precision cannot be set as a format for a QuantizationScheme"
            )

        # Mismatched group sizes are legal but hard to fuse in kernels — warn.
        if (
            inputs
            and weights
            and weights.strategy == QuantizationStrategy.GROUP
            and inputs.strategy == QuantizationStrategy.GROUP
            and weights.group_size != inputs.group_size
        ):
            warnings.warn(
                "Using GROUP strategy for both weights and input_activations "
                f"with different group sizes ({weights.group_size} vs "
                f"{inputs.group_size}) may complicate fused kernel implementations. "
                "Consider using TENSOR_GROUP strategy for both or matching group"
                " sizes.",
                UserWarning,
                stacklevel=2,
            )

        return model

    # reject unknown fields so config typos fail loudly
    model_config = ConfigDict(extra="forbid")


"""
Pre-Set Quantization Scheme Args
"""


def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
    """
    :param name: preset quantization settings name. must exist in upper case in
        PRESET_SCHEMES
    :param targets: list of quantization targets to be passed to the Scheme
    :return: new QuantizationScheme for a given name with the given targets
    :raises KeyError: if ``name`` is not a known preset
    """
    name = name.upper()

    if name not in PRESET_SCHEMES:
        raise KeyError(
            f"Unknown preset scheme name {name}, "
            f"available names: {list(PRESET_SCHEMES.keys())}"
        )

    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references

    return QuantizationScheme(
        targets=targets,
        **scheme_args,
    )


def is_preset_scheme(name: str) -> bool:
    """
    :param name: preset quantization settings name
    :return: True if the name is a preset scheme name
    """
    return name.upper() in PRESET_SCHEMES


UNQUANTIZED = dict()

NVFP4A16 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=False,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    )
)

NVFP4 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=False,
        group_size=16,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
    input_activations=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=DynamicType.LOCAL,
        group_size=16,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
)

MXFP4A16 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    )
)

MXFP4 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
    input_activations=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        dynamic=True,
        symmetric=True,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
)

# 8 bit integer weights and 8 bit activations quantization
INT8_W8A8 = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
)

# 8 bit integer weights only quantization
W8A16 = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
)

# 4 bit integer weights only quantization
W4A16 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=False,
    ),
)

# 4 bit integer weights only asymmetric quantization
W4A16_ASYM = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=False,
        dynamic=False,
    ),
)

# 4 bit integer weights and 8 bit activations quantization
INT8_W4A8 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        group_size=128,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
)

# FP8 weights and FP8 activations quantization
FP8 = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
)

# FP8 weights and FP8 dynamic activations quantization
FP8_DYNAMIC = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
)

# Block‐wise FP8 (deepseekv3-style quantization):
# static 128x128 per‐block weights and
# dynamic per‐token‐group activations
FP8_BLOCK = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.BLOCK,
        symmetric=True,
        dynamic=False,
        block_structure=[128, 128],
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=True,
        observer=None,
        group_size=128,
    ),
)

PRESET_SCHEMES = {
    # Unquantized (no-op)
    "UNQUANTIZED": UNQUANTIZED,
    # Integer weight only schemes
    "W8A16": W8A16,
    "W4A16": W4A16,
    "W4A16_ASYM": W4A16_ASYM,
    # Integer weight and activation schemes
    "W8A8": INT8_W8A8,
    "INT8": INT8_W8A8,  # alias for W8A8
    "W4A8": INT8_W4A8,
    # Float weight and activation schemes
    "FP8": FP8,
    "FP8_DYNAMIC": FP8_DYNAMIC,
    "FP8_BLOCK": FP8_BLOCK,
    "NVFP4A16": NVFP4A16,
    "NVFP4": NVFP4,
    "MXFP4A16": MXFP4A16,
    "MXFP4": MXFP4,
}