from functools import wraps
from math import ceil
from typing import Optional

import torch
from compressed_tensors.quantization.quant_args import (
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    round_to_quantized_type_args,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    calculate_range,
    compute_dynamic_scales_and_zp,
)
from torch.nn import Module


__all__ = [
    "quantize",
    "dequantize",
    "fake_quantize",
    "wrap_module_forward_quantized",
    "forward_quantize",
]


@torch.no_grad()
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Quantize the input tensor x using the QuantizationStrategy specified in args.
    Quantization can be done per tensor, channel, token or group. For group
    quantization, the column size must be divisible by the group_size. The input
    scale and zero_points are reshaped to support vectorization (assumes dim 1 is
    the channel dimension).

    :param x: input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param dtype: optional dtype to cast the quantized output to
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: quantized tensor
    """
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        dtype=dtype,
        do_quantize=True,
        do_dequantize=False,
        g_idx=g_idx,
        global_scale=global_scale,
    )


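# Usage sketch (illustrative only, not part of this module: the weight values
# and hand-computed scale are made up, and dividing by 127 assumes the default
# QuantizationArgs describe symmetric 8-bit integer quantization):
#
#     w = torch.randn(16, 32)
#     args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
#     scale = w.abs().amax() / 127
#     zero_point = torch.zeros_like(scale)
#     w_q = quantize(w, scale, zero_point, args, dtype=torch.int8)
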
@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    args: Optional[QuantizationArgs] = None,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequantize a quantized input tensor x_q based on the strategy specified in
    args. If args is not provided, the strategy is inferred from the shape of
    scale.

    :param x_q: quantized input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args used to quantize x_q
    :param dtype: optional dtype to cast the dequantized output to
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: dequantized float tensor
    """
    if args is None:
        if scale.ndim == 0 or scale.ndim == 1:
            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        elif scale.ndim == 2:
            # one scale per row -> channel-wise quantization
            if scale.shape[1] == 1:
                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
            # scales tile the columns -> group quantization with inferred size
            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                group_size = int(x_q.shape[1] / scale.shape[1])
                args = QuantizationArgs(
                    strategy=QuantizationStrategy.GROUP, group_size=group_size
                )
            # otherwise the scale grid tiles the tensor -> block quantization
            else:
                rows, cols = x_q.shape[-2], x_q.shape[-1]
                block_height = rows // scale.shape[0]
                block_width = cols // scale.shape[1]

                args = QuantizationArgs(
                    strategy=QuantizationStrategy.BLOCK,
                    block_structure=[block_height, block_width],
                )
        else:
            raise ValueError(
                f"Could not infer a quantization strategy from scale with "
                f"{scale.ndim} dimensions. Expected 0-2 dimensions."
            )

    if dtype is None:
        dtype = scale.dtype

    return _process_quantization(
        x=x_q,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=False,
        do_dequantize=True,
        dtype=dtype,
        g_idx=g_idx,
        global_scale=global_scale,
    )


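# Inference sketch for the branch above: with a (4, 8) quantized tensor and a
# (1, 2) scale, dequantize infers GROUP quantization with
# group_size = 8 // 2 = 4; a (4, 1) scale would instead infer CHANNEL. The
# values here are illustrative, not calibrated:
#
#     x_q = torch.randint(-128, 128, (4, 8)).to(torch.int8)
#     scale = torch.tensor([[0.02, 0.05]])  # one scale per 4-column group
#     x_dq = dequantize(x_q, scale)  # float tensor, shape (4, 8)
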
@torch.no_grad()
def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fake quantize the input tensor x by quantizing then dequantizing with
    the QuantizationStrategy specified in args. Quantization can be done per
    tensor, channel, token or group. For group quantization, the column size
    must be divisible by the group_size. The input scale and zero_points are
    reshaped to support vectorization (assumes dim 1 is the channel dimension).

    :param x: input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant to scale the quantization scale during QDQ
    :return: fake quantized tensor
    """
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=True,
        do_dequantize=True,
        g_idx=g_idx,
        global_scale=global_scale,
    )


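# Usage sketch: fake quantization is a quantize/dequantize round trip, so the
# output stays floating point but carries the rounding error of the target
# format. A quick way to inspect that error (illustrative per-channel scales,
# again assuming symmetric 8-bit defaults for QuantizationArgs):
#
#     x = torch.randn(4, 8)
#     args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
#     scale = x.abs().amax(dim=-1, keepdim=True) / 127  # shape (4, 1)
#     zero_point = torch.zeros_like(scale)
#     max_err = (fake_quantize(x, scale, zero_point, args) - x).abs().max()
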
@torch.no_grad()
def _process_quantization(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    do_quantize: bool = True,
    do_dequantize: bool = True,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    q_min, q_max = calculate_range(args, x.device)
    group_size = args.group_size

    if args.strategy == QuantizationStrategy.BLOCK:
        original_shape = x.shape
        rows, cols = x.shape[-2], x.shape[-1]
        block_height, block_width = args.block_structure

        # block quantization requires the tensor to tile exactly
        if rows % block_height != 0:
            raise ValueError(
                f"Tensor height {rows} is not divisible by block_height "
                f"{block_height}. Block quantization requires exact division."
            )
        if cols % block_width != 0:
            raise ValueError(
                f"Tensor width {cols} is not divisible by block_width "
                f"{block_width}. Block quantization requires exact division."
            )

        # reshape so each block occupies the last two dimensions
        num_rows_blocks = rows // block_height
        num_cols_blocks = cols // block_width
        x_blocks = x.reshape(
            num_rows_blocks,
            block_height,
            num_cols_blocks,
            block_width,
        ).transpose(1, 2)

        # expand scale/zero_point so they broadcast one value per block
        sb = scale.unsqueeze(-1).unsqueeze(-1)
        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
        if do_quantize:
            x_blocks = _quantize(
                x=x_blocks,
                scale=sb,
                zero_point=zb,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            x_blocks = _dequantize(
                x_q=x_blocks,
                scale=sb,
                zero_point=zb,
                global_scale=global_scale,
            )
        # undo the blocking to restore the original layout
        output = x_blocks.transpose(1, 2).reshape(original_shape)
    elif args.strategy in (
        QuantizationStrategy.GROUP,
        QuantizationStrategy.TENSOR_GROUP,
    ):
        output_dtype = dtype if dtype is not None else x.dtype
        output = torch.zeros_like(x).to(output_dtype)
        columns = output.shape[-1]

        # unsqueeze scale and zero_point to 2D so they broadcast over groups
        while scale.ndim < 2:
            scale = scale.unsqueeze(1)
            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

        if columns >= group_size:
            if columns % group_size != 0:
                raise ValueError(
                    "tensor column shape must be divisible "
                    f"by the given group_size {group_size} but got {columns}"
                )

        # support standard column order as well as activation reordering,
        # where columns are permuted into groups according to g_idx
        is_column_order = g_idx is None or -1 in g_idx
        if is_column_order:
            num_groups = int(ceil(columns / group_size))
            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
        else:
            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
            group_sizes = group_sizes[torch.argsort(group_indices)]

            perm = torch.argsort(g_idx)
            x = x.index_select(-1, perm)

        # fold the column dimension into (num_groups, group_size)
        reshaped_dims = (
            ceil(x.shape[-1] / group_size),
            group_size,
        )
        x = x.unflatten(-1, reshaped_dims)

        if do_quantize:
            output = _quantize(
                x=x,
                scale=scale.unsqueeze(-1),
                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
                dtype=dtype,
                global_scale=global_scale,
                q_min=q_min,
                q_max=q_max,
                args=args,
            )

        if do_dequantize:
            input = output if do_quantize else x
            output = _dequantize(
                x_q=input,
                scale=scale.unsqueeze(-1),
                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
                global_scale=global_scale,
            )

        output = output.flatten(start_dim=-2)
        output = output.to(output_dtype)

        # restore the original column order if it was permuted above
        if not is_column_order:
            inv_perm = torch.argsort(perm)
            output = output.index_select(-1, inv_perm)
    else:  # covers tensor, channel, and token strategies
        if do_quantize:
            output = _quantize(
                x=x,
                scale=scale,
                zero_point=zero_point,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            output = _dequantize(
                output if do_quantize else x,
                scale=scale,
                zero_point=zero_point,
                global_scale=global_scale,
            )

    return output


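# Shape sketch for the BLOCK branch of _process_quantization: a (rows, cols)
# tensor reshaped to (num_rows_blocks, block_height, num_cols_blocks,
# block_width) and transposed on dims 1 and 2 puts each block in the last two
# dimensions, so a (num_rows_blocks, num_cols_blocks) scale unsqueezed twice
# broadcasts one scale per block. Verified on a small example:
#
#     x = torch.arange(16.0).reshape(4, 4)
#     blocks = x.reshape(2, 2, 2, 2).transpose(1, 2)  # block_structure=[2, 2]
#     assert torch.equal(blocks[0, 1], torch.tensor([[2.0, 3.0], [6.0, 7.0]]))
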
def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    # capture the module's original (unbound) forward so it can be invoked
    # from inside the wrapper
    if hasattr(module.forward, "__func__"):
        forward_func_orig = module.forward.__func__
    else:
        forward_func_orig = module.forward.func

    @wraps(forward_func_orig)
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # quantization is disabled on this module, fall back to the
            # original forward
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

        input_ = args[0]

        compressed = module.quantization_status == QuantizationStatus.COMPRESSED

        if scheme.input_activations is not None:
            # quantize the input activations before the forward call
            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

        if scheme.weights is not None and not compressed:
            # temporarily swap in the (fake) quantized weight
            unquantized_weight = self.weight.data.clone()
            self.weight.data = forward_quantize(
                module, self.weight, "weight", scheme.weights
            )

        # perform wrapped forward call
        output = forward_func_orig.__get__(module, module.__class__)(
            input_, *args[1:], **kwargs
        )

        # restore the unquantized weight
        if scheme.weights is not None and not compressed:
            self.weight.data = unquantized_weight

        if scheme.output_activations is not None:
            # while calibrating static output quantization, return the
            # unquantized output so observers see the true values
            if (
                module.quantization_status == QuantizationStatus.CALIBRATION
                and not scheme.output_activations.dynamic
            ):
                return output

            output = forward_quantize(
                module, output, "output", scheme.output_activations
            )
        return output

    # bind wrapped forward to the module class so the reference to `self` is
    # correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    setattr(module, "forward", bound_wrapped_forward)


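# Usage sketch (hedged: the QuantizationScheme constructor arguments below are
# assumptions, not taken from this file, and static weight quantization
# additionally requires `weight_scale`/`weight_zero_point` attributes to have
# been attached to the module during initialization elsewhere):
#
#     layer = torch.nn.Linear(32, 16)
#     scheme = QuantizationScheme(targets=["Linear"], weights=QuantizationArgs())
#     layer.quantization_status = QuantizationStatus.CALIBRATION
#     wrap_module_forward_quantized(layer, scheme)
#     # layer.forward now fake-quantizes the weight around each call
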
def forward_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
    # in compressed mode the weight is already compressed and quantized, so
    # fake quantization is not needed
    if (
        module.quantization_status == QuantizationStatus.COMPRESSED
        and base_name == "weight"
    ):
        return value

    if value.numel() == 0:
        # if the tensor is empty (e.g. due to padding), there is nothing to
        # quantize
        return value

    g_idx = getattr(module, "weight_g_idx", None)
    global_scale = getattr(module, f"{base_name}_global_scale", None)

    if args.dynamic in (True, DynamicType.LOCAL):
        # dynamic quantization: compute scale and zero point on the fly
        scale, zero_point = compute_dynamic_scales_and_zp(
            value=value, args=args, module=module, global_scale=global_scale
        )
    else:
        # static quantization: look up previously calibrated values
        scale = getattr(module, f"{base_name}_scale")
        zero_point = getattr(module, f"{base_name}_zero_point", None)

    return fake_quantize(
        x=value,
        scale=scale,
        zero_point=zero_point,
        args=args,
        g_idx=g_idx,
        global_scale=global_scale,
    )


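# The two helpers below implement the standard affine quantize/dequantize pair,
# with `scale` first divided by `global_scale` when one is provided:
#
#     q     = round_and_clamp(x / scale + zero_point)   # into [q_min, q_max]
#     x_hat = (q - zero_point) * scale
#
# so for symmetric quantization (zero_point of zero) the round trip reduces to
# x_hat = round(x / scale) * scale.
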
@torch.no_grad()
def _quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_min: torch.Tensor,
    q_max: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # if a global scale is provided, use it to further scale the local scale
    if global_scale is not None:
        scale = scale / global_scale

    scaled = x / scale

    if zero_point is not None:
        scaled += zero_point.to(x.dtype)

    # round and clamp to the quantized type's representable range
    quantized_value = round_to_quantized_type_args(
        tensor=scaled, args=args, min=q_min, max=q_max
    )

    if dtype is not None:
        quantized_value = quantized_value.to(dtype)

    return quantized_value


@torch.no_grad()
def _dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # if a global scale is provided, use it to further scale the local scale
    if global_scale is not None:
        scale = scale / global_scale

    dequant_value = x_q.to(scale.dtype)

    if zero_point is not None:
        dequant_value = dequant_value - zero_point.to(scale.dtype)

    dequant_value = dequant_value * scale

    if dtype is not None:
        dequant_value = dequant_value.to(dtype)

    return dequant_value


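# Worked instance of the round trip above (pure arithmetic, independent of any
# particular QuantizationArgs): with scale = 0.5 and zero_point = 0, the input
# x = 3.2 quantizes to round(3.2 / 0.5) = round(6.4) = 6 and dequantizes to
# 6 * 0.5 = 3.0, an absolute error of 0.2, which is at most scale / 2 for any
# input inside the clamp range.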