# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import Generator, Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import (
    FP4_E2M1_DATA,
    FP8_E4M3_DATA,
    FloatArgs,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
    round_to_quantized_type_dtype,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils.mxfp4_utils import (
    generate_mxfp4_scales,
    maybe_convert_from_mxfp4_exp,
    should_generatre_mxfp4_scales,
)
from compressed_tensors.utils import deprecated
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module


__all__ = [
    "is_module_quantized",
    "is_model_quantized",
    "module_type",
    "get_torch_bit_depth",
    "can_quantize",
    "KV_CACHE_TARGETS",
    "is_kv_cache_quant_scheme",
    "iter_named_leaf_modules",
    "iter_named_quantizable_modules",
    "compute_dynamic_scales_and_zp",
    "calculate_range",
    "calculate_qparams",
    "generate_gparam",
    "strategy_cdiv",
]

# target the self_attn layer
# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
KV_CACHE_TARGETS = ["re:.*self_attn$"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


def calculate_qparams(
    min_vals: Tensor,
    max_vals: Tensor,
    quantization_args: QuantizationArgs,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Calculate quantization scale(s) and zero point(s) from observed
    min/max values.

    :param min_vals: tensor of min value(s) to calculate scale(s)
        and zero point(s) from
    :param max_vals: tensor of max value(s) to calculate scale(s)
        and zero point(s) from
    :param quantization_args: settings to quantization
    :param global_scale: additional global scale to scale the locally generated
        scale; currently only applied/supported for Fp4
    :return: tuple of the calculated scale(s) and zero point(s). For FP4, the
        calculated scale is of dtype FP8
    """
    # based on the implementations for consuming quantized values,
    # 0.0 must always be representable within the quantized range
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    device = min_vals.device
    bit_min, bit_max = calculate_range(quantization_args, device)
    bit_range = bit_max - bit_min

    # 1. Generate scale and zero-point
    if quantization_args.symmetric:
        # symmetric: scale from the largest absolute observed value,
        # zero point is always 0
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        if should_generatre_mxfp4_scales(args=quantization_args):
            scales = generate_mxfp4_scales(x=max_val_pos)
        else:
            scales = max_val_pos / (float(bit_range) / 2)
        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
    else:
        if (
            quantization_args.num_bits == 4
            and quantization_args.type == QuantizationType.FLOAT
        ):
            raise NotImplementedError(
                "Asymmetric Quantization is not supported for FP4"
            )
        scales = (max_vals - min_vals) / float(bit_range)
        # NOTE(review): if scales contains zeros here (max == min == 0), this
        # division produces non-finite zero points; zeros are only replaced
        # with eps at step 5 below — confirm this ordering is intended
        zero_points = bit_min - (min_vals / scales)
        zero_points = torch.clamp(zero_points, bit_min, bit_max)

    # 2. Conditionally scale the generated local scale by a global_scale
    if global_scale is not None:
        scales = global_scale * scales

    # 3. Conditionally round the scale to the quantized dtype, if scale_dtype is set
    if quantization_args.scale_dtype is not None:
        scales = round_to_quantized_type_dtype(
            scales, dtype=quantization_args.scale_dtype
        )

    # 4. Optionally remove exponent
    scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)

    # 5. Update any 0s with small values to prevent div by 0
    eps = _get_dtype_eps(
        dtype=quantization_args.scale_dtype
        if quantization_args.scale_dtype is not None
        else scales.dtype
    )
    scales = torch.where(
        scales == 0,
        torch.tensor(eps, dtype=scales.dtype, device=device),
        scales,
    )

    # 6. Round the zp to zp_dtype
    zero_points = round_to_quantized_type_dtype(
        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
    )

    # always return at least 1-D tensors so downstream consumers can index
    if scales.ndim == 0:
        scales = scales.reshape(1)
        zero_points = zero_points.reshape(1)

    return scales, zero_points


def compute_dynamic_scales_and_zp(
    value: Tensor,
    args: QuantizationArgs,
    module: torch.nn.Module,
    global_scale: Optional[Tensor] = None,
):
    """
    Returns the computed scales and zero points for dynamic activation
    quantization.

    :param value: tensor to calculate quantization parameters for
    :param args: quantization args; the strategy determines which dimensions
        are reduced when computing min/max (TOKEN, TENSOR, TENSOR_GROUP, GROUP)
    :param module: owning module; currently unused here — kept for interface
        compatibility with callers
    :param global_scale: optional global scale forwarded to calculate_qparams
        (FP4 only)
    :return: tuple of scale and zero point derived from the observed tensor
    """
    keep_dims = True
    if args.strategy == QuantizationStrategy.TOKEN:
        # per-token: keep dims 0 and 1, reduce everything else
        dim = {0, 1}
        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
    elif args.strategy == QuantizationStrategy.TENSOR:
        reduce_dims = None
    elif args.strategy in (
        QuantizationStrategy.TENSOR_GROUP,
        QuantizationStrategy.GROUP,
    ):
        # split the last dim into (num_groups, group_size) and reduce per group
        reduce_dims = -1
        keep_dims = False
        reshaped_dims = (
            math.ceil(value.shape[-1] / args.group_size),
            args.group_size,
        )
        value = value.unflatten(-1, reshaped_dims)
    else:
        supported_strategies = (
            QuantizationStrategy.TOKEN,
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.TENSOR_GROUP,
            QuantizationStrategy.GROUP,
        )
        # BUG FIX: ValueError was previously constructed with two arguments
        # (a string and an f-string), producing a tuple-rendered message;
        # build a single message instead
        raise ValueError(
            f"Dynamic quantization is only supported for {supported_strategies}"
        )

    if not reduce_dims:
        min_val, max_val = torch.aminmax(value)
    else:
        min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
        max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)

    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)


def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
    """
    Calculate the effective quantization range for the given Quantization Args

    :param quantization_args: quantization args to get range of
    :param device: device to store the range to
    :return: tuple endpoints (q_min, q_max) for the given quantization range
    :raises NotImplementedError: for FLOAT types other than 4 or 8 bits
    :raises ValueError: for an unrecognized quantization type
    """
    if quantization_args.type == QuantizationType.INT:
        bit_range = 2**quantization_args.num_bits
        q_max = torch.tensor(bit_range / 2 - 1, device=device)
        q_min = torch.tensor(-bit_range / 2, device=device)
    elif quantization_args.type == QuantizationType.FLOAT:
        if quantization_args.num_bits == 8:
            q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
            q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
        elif quantization_args.num_bits == 4:
            q_max = torch.tensor(FP4_E2M1_DATA.max, device=device)
            q_min = torch.tensor(FP4_E2M1_DATA.min, device=device)
        else:
            raise NotImplementedError(
                "Range calculation only supported for 4 and 8 bits"
            )
    else:
        raise ValueError(f"Invalid quantization type {quantization_args.type}")

    return q_min, q_max


def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty
    quantization scheme

    :param module: pytorch module to check
    :return: True if module is quantized, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

    # quantized if any of weights / input / output activations are configured
    if module.quantization_scheme.weights is not None:
        return True

    if module.quantization_scheme.input_activations is not None:
        return True

    if module.quantization_scheme.output_activations is not None:
        return True

    return False


def is_model_quantized(model: Module) -> bool:
    """
    Check if any modules in a model are quantized, based on the existence of
    a non-empty quantization scheme in at least one module

    :param model: pytorch model
    :return: True if model is quantized, False otherwise
    """
    return any(is_module_quantized(submodule) for submodule in model.modules())
def module_type(module: Module) -> str:
    """
    Gets a string representation of a module type

    :param module: pytorch module to get type of
    :return: module type as a string
    """
    return type(module).__name__


@deprecated(
    message="This function will be removed in a future release. "
    "Please use `model.named_modules()` and filter by "
    "compressed_tensors.InternalModule if neceessary"
)
def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
    """
    Yields modules that do not have any submodules except observers. The observers
    themselves are not yielded

    :param model: model to get leaf modules of
    :returns: generator tuple of (name, leaf_submodule)
    """
    for name, submodule in model.named_modules():
        children = list(submodule.children())
        # TODO: verify if an observer would ever be attached in this case/remove check
        # BUG FIX: the check was previously `"observer" in name`, which yielded
        # ONLY observer leaves — the inverse of the documented contract and of
        # the matching check in iter_named_quantizable_modules
        if len(children) == 0 and "observer" not in name:
            yield name, submodule
        else:
            if len(children) > 0:
                named_children, children = zip(*list(submodule.named_children()))
                # a module whose children are all observers counts as a leaf
                has_non_observer_children = False
                for i in range(len(children)):
                    child_name = named_children[i]
                    if "observer" not in child_name:
                        has_non_observer_children = True
                if not has_non_observer_children:
                    yield name, submodule


@deprecated(
    message="This function will be removed in a future release. "
    "Please use `model.named_modules()` and filter by "
    "compressed_tensors.InternalModule if neceessary"
)
def iter_named_quantizable_modules(
    model: Module,
    include_children: bool = True,
    include_attn: bool = False,
    include_mlp: bool = False,
) -> Generator[Tuple[str, Module], None, None]:
    """
    Yield name and submodule of
    - leaf modules, set by include_children
    - attention modules, set by include_attn
    - mlp modules, set by include_mlp

    :param model: model to get leaf modules of
    :param include_children: flag to get the leaf modules
    :param include_attn: flag to get the attention modules
    :param include_mlp: flag to get the mlp modules
    :returns: generator tuple of (name, submodule)
    """
    for name, submodule in model.named_modules():
        # TODO: verify if an observer would ever be attached in this case/remove check
        if include_children:
            children = list(submodule.children())
            if len(children) == 0 and "observer" not in name:
                yield name, submodule
            else:
                if len(children) > 0:
                    named_children, children = zip(*list(submodule.named_children()))
                    # a module whose children are all observers counts as a leaf
                    has_non_observer_children = False
                    for i in range(len(children)):
                        child_name = named_children[i]
                        if "observer" not in child_name:
                            has_non_observer_children = True
                    if not has_non_observer_children:
                        yield name, submodule
        if include_attn:
            if name.endswith("self_attn"):
                yield name, submodule
        if include_mlp:
            if name.endswith("mlp"):
                yield name, submodule


def get_torch_bit_depth(value: torch.Tensor) -> int:
    """
    Determine the number of bits used to represent the dtype of a tensor

    :param value: tensor to check bit depth of
    :return: bit depth of each element in the value tensor
    """
    try:
        # float dtypes answer to finfo; integer dtypes raise TypeError
        bit_depth = torch.finfo(value.dtype).bits
    except TypeError:
        bit_depth = torch.iinfo(value.dtype).bits

    return bit_depth


def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:  # noqa
    """
    Checks if value can be quantized by quant_args.

    :param value: tensor to check for quantization
    :param quant_args: QuantizationArgs to use for quantization
    :return: False if value is already quantized to quant_args or value is
        incompatible with quant_args, True if value can be quantized with
        quant_args
    """
    bit_depth = get_torch_bit_depth(value)
    requested_depth = quant_args.num_bits
    if bit_depth < quant_args.num_bits:
        # BUG FIX: Logger.warn is a deprecated alias of Logger.warning; also
        # the concatenated message previously lacked a separating space
        _LOGGER.warning(
            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}. "
            "The QuantizationArgs provided are not compatible with the input tensor."
        )

    return bit_depth > quant_args.num_bits


@deprecated()
def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
    """
    Check whether the QuantizationScheme targets the kv cache.
    It does if all the following criteria are met:
    - the scheme targets either exactly match the KV_CACHE_TARGETS
      or the match KV_CACHE_TARGETS regex pattern
    - the scheme quantizes output_activations (we want to quantize the
      outputs from the KV_CACHE_TARGETS, as their correspond to the
      keys and values that are to be saved in the cache)

    NOTE(review): the implementation only checks exact membership in
    KV_CACHE_TARGETS — the regex-pattern and output_activations criteria
    described above are not enforced here; confirm against callers.

    :param scheme: The QuantizationScheme to investigate
    :return: boolean flag
    """
    for target in scheme.targets:
        if target in KV_CACHE_TARGETS:
            return True

    return False


def generate_gparam(
    updated_min_val: torch.Tensor,
    updated_max_val: torch.Tensor,
    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
    dtype: Optional[torch.dtype] = torch.float32,
):
    """
    Generate a global scale for an entire tensor (input_tensor).
    Goal of the scale is to ensure that the quantization (local) scale
    falls into the appropriate dtype range.

    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
    attempts to use the entire FP8 dtype range while mapping a per-group max
    to the FP4 max.

    :param updated_min_val: observed min value(s)
    :param updated_max_val: observed max value(s)
    :param scale_data: dtype data of the local scales (default FP8)
    :param quant_data: dtype data of the quantized values (default FP4)
    :param dtype: dtype of the returned global scale
    :return: 1-element tensor holding the global scale
    """
    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
    global_scale = scale_data.max * quant_data.max / max_val_pos
    return global_scale.to(dtype).reshape([1])


def strategy_cdiv(
    value: int,
    divisor: int,
    strategy: Optional[QuantizationStrategy],
    strict: bool = False,
) -> int:
    """
    Ceiling-divide `value` by `divisor`, warning (or raising when `strict`)
    if the division is not exact.

    :param value: weight/activation size being partitioned
    :param divisor: group/block size
    :param strategy: quantization strategy, used only in the message
    :param strict: raise ValueError instead of warning on inexact division
    :return: ceil(value / divisor)
    """
    quotient = math.ceil(value / divisor)
    if quotient * divisor != value:
        message = (
            f"{strategy} quantization strategy requires strict division of "
            f"weight/activation size {value} and group/block size {divisor}. "
            "consider reducing the group/block size or ignoring modules with "
            f"weights not divisible by {divisor}"
        )
        if strict:
            raise ValueError(message)
        else:
            # log_once avoids spamming the same warning per module
            logger.bind(log_once=True).warning(message)

    return quotient


def _get_dtype_eps(dtype: torch.dtype) -> float:
    """
    Smallest positive value used to replace zero scales for the given dtype.

    :param dtype: scale dtype to look up
    :return: eps for floating dtypes (with FP8/FP4 special-cased), else 1
    """
    if dtype == FP8_E4M3_DATA.dtype:
        return 0.125
    elif dtype == FP4_E2M1_DATA.dtype:
        return 0.25
    elif torch.is_floating_point(torch.tensor([], dtype=dtype)):
        return torch.finfo(dtype).eps
    else:
        # integer dtypes: smallest nonzero representable step
        return 1