# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import Generator, Optional, Tuple
import torch
from compressed_tensors.quantization.quant_args import (
FP4_E2M1_DATA,
FP8_E4M3_DATA,
FloatArgs,
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
round_to_quantized_type_dtype,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils.mxfp4_utils import (
generate_mxfp4_scales,
maybe_convert_from_mxfp4_exp,
should_generatre_mxfp4_scales,
)
from compressed_tensors.utils import deprecated
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
__all__ = [
"is_module_quantized",
"is_model_quantized",
"module_type",
"get_torch_bit_depth",
"can_quantize",
"KV_CACHE_TARGETS",
"is_kv_cache_quant_scheme",
"iter_named_leaf_modules",
"iter_named_quantizable_modules",
"compute_dynamic_scales_and_zp",
"calculate_range",
"calculate_qparams",
"generate_gparam",
"strategy_cdiv",
]
# target the self_attn layer
# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
KV_CACHE_TARGETS = ["re:.*self_attn$"]
_LOGGER: logging.Logger = logging.getLogger(__name__)
def calculate_qparams(
min_vals: Tensor,
max_vals: Tensor,
quantization_args: QuantizationArgs,
global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
:param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
from
:param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
from
    :param quantization_args: quantization settings to use
:param global_scale: additional global scale to scale the locally generated scale
currently only applied/supported for Fp4
:return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
scale is of dtype FP8
"""
# based on the implementations for consuming quantized values,
# 0.0 must always be representable within the quantized range
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
device = min_vals.device
bit_min, bit_max = calculate_range(quantization_args, device)
bit_range = bit_max - bit_min
# 1. Generate scale and zero-point
if quantization_args.symmetric:
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
if should_generatre_mxfp4_scales(args=quantization_args):
scales = generate_mxfp4_scales(x=max_val_pos)
else:
scales = max_val_pos / (float(bit_range) / 2)
zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
else:
if (
quantization_args.num_bits == 4
and quantization_args.type == QuantizationType.FLOAT
):
raise NotImplementedError(
"Asymmetric Quantization is not supported for FP4"
)
scales = (max_vals - min_vals) / float(bit_range)
zero_points = bit_min - (min_vals / scales)
zero_points = torch.clamp(zero_points, bit_min, bit_max)
# 2. Conditionally scale the generated local scale by a global_scale
if global_scale is not None:
scales = global_scale * scales
# 3. Conditionally round the scale to the quantized dtype, if scale_dtype is set
if quantization_args.scale_dtype is not None:
scales = round_to_quantized_type_dtype(
scales, dtype=quantization_args.scale_dtype
)
# 4. Optionally remove exponent
scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
    # 5. Replace any zero scales with a small epsilon to prevent division by zero
eps = _get_dtype_eps(
dtype=quantization_args.scale_dtype
if quantization_args.scale_dtype is not None
else scales.dtype
)
scales = torch.where(
scales == 0,
torch.tensor(eps, dtype=scales.dtype, device=device),
scales,
)
# 6. Round the zp to zp_dtype
zero_points = round_to_quantized_type_dtype(
zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
)
if scales.ndim == 0:
scales = scales.reshape(1)
zero_points = zero_points.reshape(1)
return scales, zero_points
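# The following is a hedged usage sketch, not part of the module API: it shows
# how calculate_qparams might be called for a symmetric INT8, per-tensor case.
# The QuantizationArgs field values below are illustrative assumptions.
def _example_calculate_qparams() -> Tuple[FloatTensor, IntTensor]:
    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    min_vals = torch.tensor([-0.4])
    max_vals = torch.tensor([1.2])
    # symmetric path: scale = max(|min|, |max|) / (bit_range / 2) = 1.2 / 127.5
    scales, zero_points = calculate_qparams(min_vals, max_vals, args)
    return scales, zero_points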
def compute_dynamic_scales_and_zp(
value: Tensor,
args: QuantizationArgs,
module: torch.nn.Module,
global_scale: Optional[Tensor] = None,
):
"""
Returns the computed scales and zero points for dynamic activation
quantization.
    :param value: tensor to calculate quantization parameters for
    :param args: quantization args
    :param module: module that the observed tensor belongs to
    :param global_scale: optional global scale applied to the locally generated
        scale; currently only applied/supported for FP4
    :return: tuple of scale and zero point derived from the observed tensor
"""
keep_dims = True
if args.strategy == QuantizationStrategy.TOKEN:
dim = {0, 1}
reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
elif args.strategy == QuantizationStrategy.TENSOR:
reduce_dims = None
elif args.strategy in (
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
):
reduce_dims = -1
keep_dims = False
reshaped_dims = (
math.ceil(value.shape[-1] / args.group_size),
args.group_size,
)
value = value.unflatten(-1, reshaped_dims)
else:
supported_strategies = (
QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
)
        raise ValueError(
            "Dynamic quantization is only supported for "
            f"{supported_strategies}"
        )
if not reduce_dims:
min_val, max_val = torch.aminmax(value)
else:
min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
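# Hedged usage sketch (illustrative, not part of the module API): dynamic
# per-token scales for a [batch, seq_len, hidden] activation. The TOKEN path
# does not use `module`, so a bare Module() stands in here as an assumption.
def _example_compute_dynamic_scales_and_zp():
    args = QuantizationArgs(
        num_bits=8, type="float", symmetric=True, strategy="token", dynamic=True
    )
    value = torch.randn(2, 16, 64)
    # reduces over the hidden dim, giving one scale per token: shape (2, 16, 1)
    scale, zero_point = compute_dynamic_scales_and_zp(value, args, module=Module())
    return scale, zero_point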
def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
"""
    Calculates the effective quantization range for the given QuantizationArgs
:param quantization_args: quantization args to get range of
:param device: device to store the range to
:return: tuple endpoints for the given quantization range
"""
if quantization_args.type == QuantizationType.INT:
bit_range = 2**quantization_args.num_bits
q_max = torch.tensor(bit_range / 2 - 1, device=device)
q_min = torch.tensor(-bit_range / 2, device=device)
elif quantization_args.type == QuantizationType.FLOAT:
if quantization_args.num_bits == 8:
q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
elif quantization_args.num_bits == 4:
q_max = torch.tensor(FP4_E2M1_DATA.max, device=device)
q_min = torch.tensor(FP4_E2M1_DATA.min, device=device)
else:
raise NotImplementedError(
"Range calculation only supported for 4 and 8 bits"
)
else:
raise ValueError(f"Invalid quantization type {quantization_args.type}")
return q_min, q_max
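# Hedged usage sketch: for standard 8-bit integer args the endpoints are
# (-128, 127); for FP8_E4M3 they are the dtype's (min, max) = (-448, 448).
def _example_calculate_range():
    int8_args = QuantizationArgs(num_bits=8, type="int", symmetric=True)
    q_min, q_max = calculate_range(int8_args, device="cpu")
    # q_min == -128.0, q_max == 127.0
    return q_min, q_max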
def is_module_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
scheme
:param module: pytorch module to check
:return: True if module is quantized, False otherwise
"""
if not hasattr(module, "quantization_scheme"):
return False
if module.quantization_scheme.weights is not None:
return True
if module.quantization_scheme.input_activations is not None:
return True
if module.quantization_scheme.output_activations is not None:
return True
return False
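# Hedged usage sketch: attaching a minimal weight-only scheme to a Linear layer
# makes is_module_quantized return True. The scheme/args values are assumptions
# chosen for illustration only.
def _example_is_module_quantized() -> bool:
    layer = torch.nn.Linear(8, 8)
    layer.quantization_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, type="int", symmetric=True),
    )
    return is_module_quantized(layer)  # True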
def is_model_quantized(model: Module) -> bool:
"""
Check if any modules in a model are quantized, based on the existence of a non-empty
quantization scheme in at least one module
:param model: pytorch model
:return: True if model is quantized, False otherwise
"""
return any(is_module_quantized(submodule) for submodule in model.modules())
def module_type(module: Module) -> str:
"""
Gets a string representation of a module type
    :param module: pytorch module to get type of
:return: module type as a string
"""
return type(module).__name__
@deprecated(
message="This function will be removed in a future release. "
"Please use `model.named_modules()` and filter by "
"compressed_tensors.InternalModule if neceessary"
)
def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
"""
Yields modules that do not have any submodules except observers. The observers
themselves are not yielded
:param model: model to get leaf modules of
:returns: generator tuple of (name, leaf_submodule)
"""
for name, submodule in model.named_modules():
children = list(submodule.children())
# TODO: verify if an observer would ever be attached in this case/remove check
        if len(children) == 0 and "observer" not in name:
yield name, submodule
else:
if len(children) > 0:
named_children, children = zip(*list(submodule.named_children()))
has_non_observer_children = False
for i in range(len(children)):
child_name = named_children[i]
if "observer" not in child_name:
has_non_observer_children = True
if not has_non_observer_children:
yield name, submodule
@deprecated(
message="This function will be removed in a future release. "
"Please use `model.named_modules()` and filter by "
"compressed_tensors.InternalModule if neceessary"
)
def iter_named_quantizable_modules(
model: Module,
include_children: bool = True,
include_attn: bool = False,
include_mlp: bool = False,
) -> Generator[Tuple[str, Module], None, None]:
"""
    Yield name and submodule of
    - leaf modules, set by include_children
    - attention modules, set by include_attn
    - mlp modules, set by include_mlp
    :param model: model to get leaf modules of
    :param include_children: flag to get the leaf modules
    :param include_attn: flag to get the attention modules
    :param include_mlp: flag to get the mlp modules
    :returns: generator tuple of (name, submodule)
"""
for name, submodule in model.named_modules():
# TODO: verify if an observer would ever be attached in this case/remove check
if include_children:
children = list(submodule.children())
if len(children) == 0 and "observer" not in name:
yield name, submodule
else:
if len(children) > 0:
named_children, children = zip(*list(submodule.named_children()))
has_non_observer_children = False
for i in range(len(children)):
child_name = named_children[i]
if "observer" not in child_name:
has_non_observer_children = True
if not has_non_observer_children:
yield name, submodule
if include_attn:
if name.endswith("self_attn"):
yield name, submodule
if include_mlp:
if name.endswith("mlp"):
yield name, submodule
def get_torch_bit_depth(value: torch.Tensor) -> int:
"""
Determine the number of bits used to represent the dtype of a tensor
:param value: tensor to check bit depth of
:return: bit depth of each element in the value tensor
"""
try:
bit_depth = torch.finfo(value.dtype).bits
except TypeError:
bit_depth = torch.iinfo(value.dtype).bits
return bit_depth
def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool: # noqa
"""
Checks if value can be quantized by quant_args.
:param value: tensor to check for quantization
:param quant_args: QuantizationArgs to use for quantization
:return: False if value is already quantized to quant_args or value is incompatible
with quant_args, True if value can be quantized with quant_args
"""
bit_depth = get_torch_bit_depth(value)
requested_depth = quant_args.num_bits
if bit_depth < quant_args.num_bits:
        _LOGGER.warning(
            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}. "
            "The QuantizationArgs provided are not compatible with the input tensor."
)
return bit_depth > quant_args.num_bits
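# Hedged usage sketch: a float16 tensor stores 16 bits per element, so it can
# still be quantized down to 8 bits; an already-8-bit tensor could not.
def _example_can_quantize() -> bool:
    weight = torch.zeros(4, dtype=torch.float16)
    assert get_torch_bit_depth(weight) == 16
    int8_args = QuantizationArgs(num_bits=8, type="int", symmetric=True)
    return can_quantize(weight, int8_args)  # True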
@deprecated()
def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
"""
Check whether the QuantizationScheme targets the kv cache.
    It does if all the following criteria are met:
    - the scheme targets either exactly match KV_CACHE_TARGETS
      or match the KV_CACHE_TARGETS regex pattern
    - the scheme quantizes output_activations (we want to quantize the
      outputs from the KV_CACHE_TARGETS, as they correspond to the
      keys and values that are to be saved in the cache)
:param scheme: The QuantizationScheme to investigate
:return: boolean flag
"""
for target in scheme.targets:
if target in KV_CACHE_TARGETS:
return True
return False
def generate_gparam(
updated_min_val: torch.Tensor,
updated_max_val: torch.Tensor,
scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
dtype: Optional[torch.dtype] = torch.float32,
):
"""
Generate a global scale for an entire tensor (input_tensor).
    The goal of the scale is to ensure that the quantization (local) scale
    falls into the appropriate dtype range.
E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
attempts to use the entire FP8 dtype range while mapping a per-group max
to the FP4 max.
"""
min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
global_scale = scale_data.max * quant_data.max / max_val_pos
return global_scale.to(dtype).reshape([1])
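# Hedged worked example: with the default FP8_E4M3 scale dtype and FP4_E2M1
# quantized dtype, a tensor whose absolute max is 3.0 yields
# global_scale = 448.0 * 6.0 / 3.0 = 896.0.
def _example_generate_gparam() -> torch.Tensor:
    min_val = torch.tensor([-3.0])
    max_val = torch.tensor([2.5])
    return generate_gparam(min_val, max_val)  # tensor([896.])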
def strategy_cdiv(
value: int,
divisor: int,
strategy: Optional[QuantizationStrategy],
strict: bool = False,
) -> int:
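    """
    Divide ``value`` by ``divisor``, rounding up. If the division is not exact,
    warn (or raise when ``strict=True``) that the given strategy requires strict
    divisibility.
    :param value: weight/activation size being divided
    :param divisor: group/block size to divide by
    :param strategy: quantization strategy requesting the division
    :param strict: if True, raise instead of warning on inexact division
    :return: ceiling of value / divisor
    """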
dividend = math.ceil(value / divisor)
if dividend * divisor != value:
message = (
f"{strategy} quantization strategy requires strict division of "
f"weight/activation size {value} and group/block size {divisor}. "
"consider reducing the group/block size or ignoring modules with "
f"weights not divisible by {divisor}"
)
if strict:
raise ValueError(message)
else:
logger.bind(log_once=True).warning(message)
return dividend
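# Hedged usage sketch: a 4096-wide weight with group size 128 divides evenly
# into 32 groups; a width of 4000 would instead warn (or raise with strict=True).
def _example_strategy_cdiv() -> int:
    return strategy_cdiv(4096, 128, QuantizationStrategy.GROUP)  # 32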
def _get_dtype_eps(dtype: torch.dtype) -> float:
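    """
    Return a small positive value representable in ``dtype``, used to replace
    zero-valued scales and avoid division by zero.
    """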
if dtype == FP8_E4M3_DATA.dtype:
return 0.125
elif dtype == FP4_E2M1_DATA.dtype:
return 0.25
elif torch.is_floating_point(torch.tensor([], dtype=dtype)):
return torch.finfo(dtype).eps
else:
return 1