Harmony18090's picture
Add source batch 2/11
76f9669 verified
raw
history blame
17.4 kB
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from math import ceil
from typing import Optional
import torch
from compressed_tensors.quantization.quant_args import (
DynamicType,
QuantizationArgs,
QuantizationStrategy,
round_to_quantized_type_args,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
calculate_range,
compute_dynamic_scales_and_zp,
)
from torch.nn import Module
__all__ = [
"quantize",
"dequantize",
"fake_quantize",
"wrap_module_forward_quantized",
"forward_quantize",
]
@torch.no_grad()
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Quantize the input tensor ``x`` using the QuantizationStrategy specified
    in ``args``. Quantization can be done per tensor, channel, token or group.
    For group quantization, the column size must be divisible by the
    group_size. The input scale and zero_points are reshaped internally to
    support vectorization (dim 1 is assumed to be the channel dimension).

    :param x: input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param dtype: optional dtype to cast the quantized output to
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: quantized tensor (only the quantize step is applied, no dequantize)
    """
    # Delegate to the shared QDQ kernel with only the quantize step enabled.
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        g_idx=g_idx,
        global_scale=global_scale,
        dtype=dtype,
        do_quantize=True,
        do_dequantize=False,
    )
@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    args: Optional[QuantizationArgs] = None,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequantize a quantized input tensor x_q based on the strategy specified in
    args. If args is not provided, the strategy will be inferred from the
    shape of ``scale`` relative to ``x_q``.

    :param x_q: quantized input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args used to quantize x_q
    :param dtype: optional dtype to cast the dequantized output to; defaults
        to the scale's dtype
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: dequantized float tensor
    """
    if args is None:
        if scale.ndim in (0, 1):
            # scalar (or flat single-entry) scale -> per-tensor quantization
            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        elif scale.ndim == 2:
            if scale.shape[1] == 1:
                # one scale per row -> channel-wise quantization
                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
            # Scale height matches input or is 1 -> group quantization across columns
            #
            # Example 1: scale.shape[0] == 1
            #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
            #
            # Example 2: scale.shape[0] == x_q.shape[0]
            #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                group_size = x_q.shape[1] // scale.shape[1]
                args = QuantizationArgs(
                    strategy=QuantizationStrategy.GROUP, group_size=group_size
                )
            else:
                # otherwise interpret the scale as a 2D grid of blocks
                rows, cols = x_q.shape[-2], x_q.shape[-1]
                block_height = rows // scale.shape[0]  # rows per block
                block_width = cols // scale.shape[1]  # columns per block
                args = QuantizationArgs(
                    strategy=QuantizationStrategy.BLOCK,
                    block_structure=[block_height, block_width],
                )
        else:
            # scale.ndim == 1 is folded into the TENSOR case above
            raise ValueError(
                f"Could not infer a quantization strategy from scale with {scale.ndim} "
                "dimensions. Expected 0, 1, or 2 dimensions."
            )

    if dtype is None:
        dtype = scale.dtype

    # Delegate to the shared QDQ kernel with only the dequantize step enabled.
    return _process_quantization(
        x=x_q,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=False,
        do_dequantize=True,
        dtype=dtype,
        g_idx=g_idx,
        global_scale=global_scale,
    )
@torch.no_grad()
def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fake quantize the input tensor ``x``: quantize then immediately
    dequantize using the QuantizationStrategy specified in ``args``.
    Quantization can be done per tensor, channel, token or group. For group
    quantization, the column size must be divisible by the group_size. The
    input scale and zero_points are reshaped internally to support
    vectorization (dim 1 is assumed to be the channel dimension).

    :param x: input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: fake quantized tensor
    """
    # Delegate to the shared QDQ kernel with both steps enabled (quantize
    # followed by dequantize == a simulated round trip).
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        g_idx=g_idx,
        global_scale=global_scale,
        do_quantize=True,
        do_dequantize=True,
    )
@torch.no_grad()
def _process_quantization(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    do_quantize: bool = True,
    do_dequantize: bool = True,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Shared quantize/dequantize (QDQ) kernel behind ``quantize``,
    ``dequantize`` and ``fake_quantize``.

    Dispatches on ``args.strategy``:
    - BLOCK: reshapes the last two dims into a grid of 2D blocks, applies one
      scale/zero_point per block, then restores the original shape.
    - GROUP / TENSOR_GROUP: folds the last dim into (num_groups, group_size)
      and applies one scale/zero_point per group; supports activation
      ordering via ``g_idx``.
    - everything else (tensor, channel, token, attn_head): applies
      scale/zero_point directly via broadcasting.

    :param x: input tensor (the already-quantized tensor when only
        dequantizing)
    :param scale: scale tensor
    :param zero_point: zero point tensor (may be None)
    :param args: quantization args dictating strategy / group_size /
        block_structure
    :param g_idx: optional mapping from column index to group index
    :param dtype: optional dtype to cast the output to
    :param do_quantize: whether to run the quantize step
    :param do_dequantize: whether to run the dequantize step (on the
        quantized result when both flags are set)
    :param global_scale: optional constant to further scale ``scale``
    :return: processed tensor
    """
    q_min, q_max = calculate_range(args, x.device)
    group_size = args.group_size

    # blockwise FP8: quantize per 2D block, supports block_structure for static block
    # quantization
    if args.strategy == QuantizationStrategy.BLOCK:
        original_shape = x.shape
        rows, cols = x.shape[-2], x.shape[-1]
        block_height, block_width = args.block_structure

        # Ensure exact division (tensor dimensions must be divisible by block size)
        if rows % block_height != 0:
            raise ValueError(
                f"Tensor height {rows} is not divisible by block_height {block_height}."
                f" Block quantization requires exact division."
            )
        if cols % block_width != 0:
            raise ValueError(
                f"Tensor width {cols} is not divisible by block_width {block_width}. "
                f"Block quantization requires exact division."
            )

        # reshape into blocks and transpose to make each block contiguous
        num_rows_blocks = rows // block_height
        num_cols_blocks = cols // block_width
        x_blocks = x.reshape(
            num_rows_blocks,
            block_height,
            num_cols_blocks,
            block_width,
        ).transpose(1, 2)

        # expand scale/zero_point for blocks: one entry per block, broadcast
        # over the trailing (block_height, block_width) dims
        sb = scale.unsqueeze(-1).unsqueeze(-1)
        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
        if do_quantize:
            # quantize blocks
            x_blocks = _quantize(
                x=x_blocks,
                scale=sb,
                zero_point=zb,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            # dequantize blocks
            x_blocks = _dequantize(
                x_q=x_blocks,
                scale=sb,
                zero_point=zb,
                global_scale=global_scale,
            )
        # restore original shape
        output = x_blocks.transpose(1, 2).reshape(original_shape)
    elif args.strategy in (
        QuantizationStrategy.GROUP,
        QuantizationStrategy.TENSOR_GROUP,
    ):
        output_dtype = dtype if dtype is not None else x.dtype
        output = torch.zeros_like(x).to(output_dtype)
        columns = output.shape[-1]

        # TODO: make validation step for inputs
        while scale.ndim < 2:
            # pad scale and zero point dims for slicing
            scale = scale.unsqueeze(1)
            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

        if columns >= group_size:
            if columns % group_size != 0:
                raise ValueError(
                    "tensor column shape must be divisble "
                    f"by the given group_size {group_size} but got {columns}"
                )

        # support column-order (default) quantization as well as other orderings
        # such as activation ordering. Below checks if g_idx has been initialized
        is_column_order = g_idx is None or -1 in g_idx
        if is_column_order:
            # uniform groups: every group holds exactly group_size columns
            num_groups = int(ceil(columns / group_size))
            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
        else:
            # activation ordering: group sizes come from the g_idx mapping
            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
            group_sizes = group_sizes[torch.argsort(group_indices)]

            # permute columns so members of each group become adjacent
            perm = torch.argsort(g_idx)
            x = x.index_select(-1, perm)

        # Maintain all dimensions except the last dim, which is divided by group_size
        reshaped_dims = (
            ceil(x.shape[-1] / group_size),
            group_size,
        )
        x = x.unflatten(-1, reshaped_dims)

        if do_quantize:
            # trailing unsqueeze broadcasts one scale/zp over each group
            output = _quantize(
                x=x,
                scale=scale.unsqueeze(-1),
                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
                dtype=dtype,
                global_scale=global_scale,
                q_min=q_min,
                q_max=q_max,
                args=args,
            )

        if do_dequantize:
            input = output if do_quantize else x
            output = _dequantize(
                x_q=input,
                scale=scale.unsqueeze(-1),
                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
                global_scale=global_scale,
            )

        # collapse (num_groups, group_size) back into the column dimension
        output = output.flatten(start_dim=-2)
        output = output.to(output_dtype)

        if not is_column_order:
            # undo the activation-ordering permutation applied above
            inv_perm = torch.argsort(perm)
            output = output.index_select(-1, inv_perm)
    else:  # covers tensor, channel, token, and attn_head strategies
        if do_quantize:
            output = _quantize(
                x=x,
                scale=scale,
                zero_point=zero_point,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            output = _dequantize(
                output if do_quantize else x,
                scale=scale,
                zero_point=zero_point,
                global_scale=global_scale,
            )

    return output
def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    """
    Replace ``module.forward`` with a bound wrapper that fake-quantizes the
    input activations, weights, and/or output activations described by
    ``scheme`` around the original forward call.

    :param module: module whose forward should be wrapped; expected to
        already carry quantization parameters (scales, zero points, status)
    :param scheme: quantization scheme selecting which of input_activations /
        weights / output_activations to quantize
    """
    # expects a module already initialized and injected with the parameters in
    # initialize_module_for_quantization
    if hasattr(module.forward, "__func__"):
        # plain bound method: unwrap to the underlying function
        forward_func_orig = module.forward.__func__
    else:
        # NOTE(review): `.func` path presumably handles an already-wrapped
        # forward (e.g. a functools.partial) — confirm against callers
        forward_func_orig = module.forward.func

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # quantization is disabled on forward passes, return baseline
            # forward call
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

        # first positional argument is the activation input
        input_ = args[0]
        compressed = module.quantization_status == QuantizationStatus.COMPRESSED

        if scheme.input_activations is not None:
            # prehook should calibrate activations before forward call
            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

        if scheme.weights is not None and not compressed:
            # calibrate and (fake) quantize weights when applicable;
            # keep a copy so the true weights can be restored afterwards
            unquantized_weight = self.weight.data.clone()
            self.weight.data = forward_quantize(
                module, self.weight, "weight", scheme.weights
            )

        # perform wrapped forward call
        output = forward_func_orig.__get__(module, module.__class__)(
            input_, *args[1:], **kwargs
        )

        # restore back to unquantized_value
        if scheme.weights is not None and not compressed:
            self.weight.data = unquantized_weight

        if scheme.output_activations is not None:
            # forward-hook should calibrate/forward_quantize
            if (
                module.quantization_status == QuantizationStatus.CALIBRATION
                and not scheme.output_activations.dynamic
            ):
                # during static-output calibration, return the raw output so
                # the calibration hook observes unquantized values
                return output

            output = forward_quantize(
                module, output, "output", scheme.output_activations
            )

        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)

    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)
def forward_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
    """
    Fake quantize ``value`` using the quantization parameters that ``module``
    owns for ``base_name`` (e.g. "input", "weight", "output").

    :param module: module holding scales/zero points and quantization status
    :param value: tensor to fake quantize
    :param base_name: parameter-name prefix used to look up scale/zero point
    :param args: quantization args dictating how to quantize value
    :return: fake quantized tensor (or ``value`` unchanged when skipped)
    """
    # in compressed mode, the weight is already compressed and quantized so we
    # don't need to run fake quantization
    is_compressed_weight = (
        module.quantization_status == QuantizationStatus.COMPRESSED
        and base_name == "weight"
    )
    # empty tensors have nothing to quantize either
    if is_compressed_weight or value.numel() == 0:
        return value

    g_idx = getattr(module, "weight_g_idx", None)
    global_scale = getattr(module, f"{base_name}_global_scale", None)

    is_dynamic = args.dynamic in (True, DynamicType.LOCAL)
    if is_dynamic:
        # dynamic quantization - determine the scale/zp on the fly
        scale, zero_point = compute_dynamic_scales_and_zp(
            value=value, args=args, module=module, global_scale=global_scale
        )
    else:
        # static quantization - get scale and zero point from layer
        scale = getattr(module, f"{base_name}_scale")
        zero_point = getattr(module, f"{base_name}_zero_point", None)

    return fake_quantize(
        x=value,
        scale=scale,
        zero_point=zero_point,
        args=args,
        g_idx=g_idx,
        global_scale=global_scale,
    )
@torch.no_grad()
def _quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_min: torch.Tensor,
    q_max: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Core quantize step: divide ``x`` by the (optionally globally rescaled)
    scale, shift by ``zero_point``, then clamp to [q_min, q_max] and round to
    the quantized type dictated by ``args``.

    :param x: tensor to quantize
    :param scale: scale tensor (broadcast against x)
    :param zero_point: zero point tensor, or None for symmetric quantization
    :param q_min: minimum representable quantized value
    :param q_max: maximum representable quantized value
    :param args: quantization args used by the rounding routine
    :param dtype: optional dtype to cast the quantized result to
    :param global_scale: optional constant to further scale ``scale``
    :return: quantized tensor
    """
    # if a global scale is optionally provided, use it
    # to further scale the local `scale` parameter
    effective_scale = scale if global_scale is None else scale / global_scale

    shifted = x / effective_scale
    if zero_point is not None:
        shifted = shifted + zero_point.to(x.dtype)

    # clamp and round
    result = round_to_quantized_type_args(
        tensor=shifted, args=args, min=q_min, max=q_max
    )
    return result if dtype is None else result.to(dtype)
@torch.no_grad()
def _dequantize(
x_q: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor = None,
dtype: Optional[torch.dtype] = None,
global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# if a global scale is optionally provided, use it
# to further scale the local `scale` parameter
if global_scale is not None:
scale = scale / global_scale
dequant_value = x_q.to(scale.dtype)
if zero_point is not None:
dequant_value = dequant_value - zero_point.to(scale.dtype)
dequant_value = dequant_value * scale
if dtype is not None:
dequant_value = dequant_value.to(dtype)
return dequant_value