| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Dict |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| | from .config import QuantizationConfig |
| | from .observers import ObserverState |
| |
|
| |
|
@dataclass
class SmoothQuantWeights:
    """Quantized representation for SmoothQuant linear layers."""

    # Integer-quantized (smoothed) weight matrix; stored as int8 when
    # weight_bits <= 8, otherwise int16.
    weight: Tensor
    # Per-output-channel symmetric dequantization scales, shape (out, 1):
    # row_abs_max / (2**(weight_bits-1) - 1).
    weight_scales: Tensor
    # Per-input-channel factor to apply to activations before quantization
    # (the reciprocal of the SmoothQuant smoothing factor).
    input_scale: Tensor
    # Per-input-channel activation quantization step: smoothed activation
    # max divided by (2**(activation_bits-1) - 1).
    activation_scale: Tensor
| |
|
| |
|
| | def _prepare_stats( |
| | observer_state: ObserverState, |
| | weight: Tensor, |
| | epsilon: float, |
| | ) -> tuple[Tensor, Tensor]: |
| | activation_stats = observer_state.max_abs_values.to(dtype=torch.float32) |
| | if activation_stats.numel() < weight.size(1): |
| | activation_stats = torch.nn.functional.pad( |
| | activation_stats, |
| | (0, weight.size(1) - activation_stats.numel()), |
| | value=1.0, |
| | ) |
| | activation_stats = activation_stats[: weight.size(1)] |
| | activation_stats = activation_stats.clamp_min(epsilon) |
| |
|
| | weight_stats = weight.abs().amax(dim=0).clamp_min(epsilon) |
| | return activation_stats, weight_stats |
| |
|
| |
|
def quantize_linear_smooth(
    module: nn.Linear,
    observer_state: ObserverState,
    config: QuantizationConfig,
) -> SmoothQuantWeights:
    """Apply SmoothQuant to ``module`` and return its quantized representation.

    Activation outliers are migrated into the weights via a per-input-channel
    smoothing factor ``(act_max / weight_col_max) ** alpha``; the smoothed
    weights are then symmetrically quantized per output channel.
    """
    eps = config.epsilon
    bits_w = config.weight_bits
    bits_a = config.activation_bits
    # int8 storage covers widths up to 8 bits; wider widths fall back to int16.
    storage_dtype = torch.int8 if bits_w <= 8 else torch.int16

    fp_weight = module.weight.detach().to(torch.float32).clone()
    act_max, w_col_max = _prepare_stats(observer_state, fp_weight, eps)

    # Smoothing factor per input channel; floored at eps to keep it invertible.
    smoothing = (act_max / w_col_max).pow(config.alpha).clamp_min(eps)
    # Activations are divided by the smoothing factor at runtime.
    input_scale = smoothing.reciprocal().to(torch.float32)
    smoothed_weight = fp_weight * smoothing.unsqueeze(0)

    # Quantization step for the smoothed activations, per input channel.
    qmax_act = 2 ** (bits_a - 1) - 1
    activation_scale = (act_max * input_scale / qmax_act).clamp_min(eps)

    # Symmetric per-output-channel weight scales.
    qmax_w = 2 ** (bits_w - 1) - 1
    row_abs_max = smoothed_weight.abs().amax(dim=1).clamp_min(eps)
    weight_scales = (row_abs_max / qmax_w).unsqueeze(1)

    q_weight = (
        torch.round(smoothed_weight / weight_scales)
        .clamp(-(2 ** (bits_w - 1)), qmax_w)
        .to(storage_dtype)
    )

    return SmoothQuantWeights(
        weight=q_weight.cpu(),
        weight_scales=weight_scales.to(torch.float32).cpu(),
        input_scale=input_scale.cpu(),
        activation_scale=activation_scale.cpu(),
    )
| |
|
| |
|
def summarize_smoothquant(
    stats: Dict[str, SmoothQuantWeights]
) -> Dict[str, Dict[str, float]]:
    """Collapse per-layer SmoothQuant records into mean-scale summaries."""
    return {
        name: {
            "weight_scale_mean": float(record.weight_scales.mean()),
            "activation_scale_mean": float(record.activation_scale.mean()),
        }
        for name, record in stats.items()
    }
| |
|