Harmony18090's picture
Add source batch 2/11
76f9669 verified
raw
history blame
11 kB
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from copy import deepcopy
from typing import List, Optional
import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import (
FP8_E4M3_DATA,
DynamicType,
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from pydantic import BaseModel, ConfigDict, model_validator
# Public API of this module: the scheme model and the two preset lookup helpers
__all__ = ["QuantizationScheme", "preset_name_to_scheme", "is_preset_scheme"]
class QuantizationScheme(BaseModel):
    """
    Set of QuantizationArgs defining how the weights, inputs and outputs of a
    target list of modules should be quantized

    :param targets: list of modules to apply the QuantizationArgs to, can be layer
        names, layer types or a regular expression, typically ["Linear"]
    :param weights: quantization config for layer weights
    :param input_activations: quantization config for layer inputs
    :param output_activations: quantization config for layer outputs
    :param format: CompressionFormat for the layer; the mixed-precision format
        value is rejected by the validator
    """

    targets: List[str]
    weights: Optional[QuantizationArgs] = None
    input_activations: Optional[QuantizationArgs] = None
    output_activations: Optional[QuantizationArgs] = None
    format: Optional[str] = None

    @model_validator(mode="after")
    def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
        """
        Cross-field validation run after the individual fields are validated

        :param model: the scheme instance being validated
        :return: the same instance, unchanged
        :raises NotImplementedError: if the input activation strategy is not one
            of TOKEN, TENSOR, GROUP, TENSOR_GROUP or ATTN_HEAD
        :raises ValueError: if actorder is set on input or output activations,
            or if format is the mixed-precision format value
        """
        inputs = model.input_activations
        outputs = model.output_activations
        weights = model.weights

        if inputs is not None:
            # NOTE: GROUP is accepted here regardless of `dynamic`. A former
            # nested GROUP/dynamic check lived inside this branch and could
            # never run (GROUP is part of the allowed set), so it was removed.
            if inputs.strategy not in (
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.TENSOR,
                QuantizationStrategy.GROUP,
                QuantizationStrategy.TENSOR_GROUP,
                QuantizationStrategy.ATTN_HEAD,
            ):
                raise NotImplementedError(
                    f"Using {inputs.strategy} strategy is not supported for "
                    "activation quantization"
                )

            if inputs.actorder is not None:
                raise ValueError("Cannot apply actorder to input activations")

        if outputs is not None and outputs.actorder is not None:
            raise ValueError("Cannot apply actorder to output activations")

        if model.format == CompressionFormat.mixed_precision.value:
            raise ValueError(
                "mixed-precision cannot be set as a format for a QuantizationScheme"
            )

        # Mismatched group sizes are allowed, but the user is warned about them
        if (
            inputs is not None
            and weights is not None
            and weights.strategy == QuantizationStrategy.GROUP
            and inputs.strategy == QuantizationStrategy.GROUP
            and weights.group_size != inputs.group_size
        ):
            warnings.warn(
                "Using GROUP strategy for both weights and input_activations "
                f"with different group sizes ({weights.group_size} vs "
                f"{inputs.group_size}) may complicate fused kernel implementations. "
                "Consider using TENSOR_GROUP strategy for both or matching group"
                " sizes.",
                UserWarning,
                stacklevel=2,
            )

        return model

    # extra="forbid": fields not declared above are rejected at construction
    model_config = ConfigDict(extra="forbid")
"""
Pre-Set Quantization Scheme Args
"""
def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
    """
    Build a QuantizationScheme from a named preset

    :param name: preset quantization settings name; it is upper-cased before
        being looked up in PRESET_SCHEMES
    :param targets: list of quantization targets to be passed to the Scheme
    :return: new QuantizationScheme for a given name with the given targets
    :raises KeyError: if the upper-cased name is not a key of PRESET_SCHEMES
    """
    name = name.upper()

    if name in PRESET_SCHEMES:
        # deepcopy so the returned scheme does not reference the preset's args
        preset_kwargs = deepcopy(PRESET_SCHEMES[name])
        return QuantizationScheme(targets=targets, **preset_kwargs)

    raise KeyError(
        f"Unknown preset scheme name {name}, "
        f"available names: {list(PRESET_SCHEMES.keys())}"
    )
def is_preset_scheme(name: str) -> bool:
    """
    Check whether a name refers to a preset scheme (case-insensitive)

    :param name: preset quantization settings name
    :return: True if the upper-cased name is a key of PRESET_SCHEMES
    """
    upper_name = name.upper()
    return upper_name in PRESET_SCHEMES
# No quantization args at all: neither weights nor activations are configured
UNQUANTIZED = {}
# 4 bit float weights only quantization: static, symmetric, TENSOR_GROUP
# strategy with groups of 16 and FP8 E4M3 scale / zero point dtypes
NVFP4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=False,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
}
# 4 bit float weights and 4 bit float activations quantization: both use the
# TENSOR_GROUP strategy with groups of 16; weights are static while input
# activations use DynamicType.LOCAL
NVFP4 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=False,
        group_size=16,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
    "input_activations": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic=DynamicType.LOCAL,
        group_size=16,
        observer="static_minmax",
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
}
# 4 bit float weights only quantization: static, symmetric, GROUP strategy
# with groups of 32 and uint8 scale / zero point dtypes
MXFP4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
}
# 4 bit float weights and 4 bit float activations quantization: both use the
# GROUP strategy with groups of 32 and uint8 scale / zero point dtypes;
# weights are static, input activations are dynamic
MXFP4 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
    "input_activations": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        dynamic=True,
        symmetric=True,
        group_size=32,
        scale_dtype=torch.uint8,
        zp_dtype=torch.uint8,
    ),
}
# 8 bit integer weights (static, per CHANNEL) and 8 bit integer input
# activations (dynamic, per TOKEN, no observer) quantization
INT8_W8A8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
# 8 bit integer weights only quantization (static, symmetric, per CHANNEL)
W8A16 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
}
# 4 bit integer weights only quantization (static, symmetric, GROUP strategy
# with groups of 128)
W4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=False,
    ),
}
# 4 bit integer weights only asymmetric quantization (static, GROUP strategy
# with groups of 128); identical to W4A16 except symmetric=False
W4A16_ASYM = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=False,
        dynamic=False,
    ),
}
# 4 bit integer weights (static, GROUP strategy with groups of 128) and
# 8 bit integer input activations (dynamic, per TOKEN, no observer) quantization
INT8_W4A8 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        group_size=128,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
# FP8 weights and FP8 activations quantization: both static, symmetric and
# using the TENSOR strategy
FP8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
}
# FP8 weights (static, per CHANNEL) and FP8 dynamic activations (per TOKEN,
# no observer) quantization
FP8_DYNAMIC = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,
    ),
}
# Block-wise FP8 (deepseekv3-style quantization):
# static weights with a [128, 128] block structure and
# dynamic GROUP-strategy input activations with groups of 128
FP8_BLOCK = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.BLOCK,
        symmetric=True,
        dynamic=False,
        block_structure=[128, 128],
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=True,
        observer=None,
        group_size=128,
    ),
}
# Registry mapping upper-case preset names to their QuantizationScheme kwargs.
# preset_name_to_scheme and is_preset_scheme upper-case the requested name
# before looking it up here, so every key must be upper case.
PRESET_SCHEMES = {
    # Unquantized (no-op)
    "UNQUANTIZED": UNQUANTIZED,
    # Integer weight only schemes
    "W8A16": W8A16,
    "W4A16": W4A16,
    "W4A16_ASYM": W4A16_ASYM,
    # Integer weight and activation schemes
    "W8A8": INT8_W8A8,
    "INT8": INT8_W8A8,  # alias for W8A8
    "W4A8": INT8_W4A8,
    # Float schemes: weight and activation presets, plus the weight only
    # NVFP4A16 / MXFP4A16 variants
    "FP8": FP8,
    "FP8_DYNAMIC": FP8_DYNAMIC,
    "FP8_BLOCK": FP8_BLOCK,
    "NVFP4A16": NVFP4A16,
    "NVFP4": NVFP4,
    "MXFP4A16": MXFP4A16,
    "MXFP4": MXFP4,
}