"""SmoothQuant-style quantization of ``nn.Linear`` weights."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict

import torch
from torch import Tensor, nn

from .config import QuantizationConfig
from .observers import ObserverState


@dataclass
class SmoothQuantWeights:
    """Quantized representation for SmoothQuant linear layers."""

    # Integer weight codes (int8 / int16 / int32 depending on the configured
    # bit width), same shape as the source layer's weight, stored on CPU.
    weight: Tensor
    # float32 de-quantization scale per weight row, shape (rows, 1), on CPU.
    weight_scales: Tensor
    # Reciprocal of the per-column smoothing factor, on CPU.
    input_scale: Tensor
    # Per-column activation quantization step computed from the smoothed
    # activation maxima, on CPU.
    activation_scale: Tensor


def _prepare_stats(
    observer_state: ObserverState,
    weight: Tensor,
    epsilon: float,
) -> tuple[Tensor, Tensor]:
    """Return per-column activation and weight magnitudes for smoothing.

    Args:
        observer_state: Source of ``max_abs_values`` (presumably the running
            per-input-channel max-abs activations -- verify against the
            observers module).
        weight: 2-D float weight matrix; its second dimension fixes the
            length of both returned vectors.
        epsilon: Lower bound applied to both vectors so later divisions and
            powers never see zero.

    Returns:
        ``(activation_stats, weight_stats)``, two float32 vectors of length
        ``weight.size(1)`` living on ``weight``'s device.
    """
    in_features = weight.size(1)

    # Flatten and move to the weight's device: the length check below is
    # based on ``numel()``, and the caller combines these stats element-wise
    # with tensors derived from ``weight``.
    activation_stats = observer_state.max_abs_values.reshape(-1).to(
        device=weight.device, dtype=torch.float32
    )

    # Columns the observer has no statistics for are treated as max-abs 1.0.
    missing = in_features - activation_stats.numel()
    if missing > 0:
        activation_stats = torch.nn.functional.pad(
            activation_stats, (0, missing), value=1.0
        )

    # Drop any surplus entries, then guard against zeros.
    activation_stats = activation_stats[:in_features].clamp_min(epsilon)

    # Largest absolute weight in each column (reduction over rows).
    weight_stats = weight.abs().amax(dim=0).clamp_min(epsilon)
    return activation_stats, weight_stats


def _integer_dtype_for(bits: int) -> torch.dtype:
    """Pick the narrowest signed integer dtype that can hold ``bits``-bit codes."""
    if bits <= 8:
        return torch.int8
    if bits <= 16:
        return torch.int16
    # int16 would overflow when casting codes wider than 16 bits.
    return torch.int32


def quantize_linear_smooth(
    module: nn.Linear,
    observer_state: ObserverState,
    config: QuantizationConfig,
) -> SmoothQuantWeights:
    """Apply SmoothQuant to a linear layer.

    Produces integer-quantized weights together with the scales needed to
    smooth and quantize the layer's inputs.

    Args:
        module: Linear layer whose weight is quantized. The layer itself is
            not modified; a detached float32 copy of the weight is used.
        observer_state: Activation statistics consumed by ``_prepare_stats``.
        config: Supplies ``weight_bits``, ``activation_bits``, ``epsilon``
            and ``alpha``.

    Returns:
        A ``SmoothQuantWeights`` whose tensors have all been moved to CPU.
    """
    weight_bits = config.weight_bits
    activation_bits = config.activation_bits
    epsilon = config.epsilon
    alpha = config.alpha

    quant_dtype = _integer_dtype_for(weight_bits)

    weight = module.weight.detach().to(torch.float32).clone()
    activation_stats, weight_stats = _prepare_stats(observer_state, weight, epsilon)

    # NOTE(review): the SmoothQuant paper defines the factor as
    # act_max**alpha / weight_max**(1 - alpha); the expression below equals
    # that only when alpha == 0.5. Left as-is -- confirm this is intended.
    ratio = activation_stats / weight_stats
    smoothing_factor = torch.pow(ratio, alpha).clamp_min(epsilon)

    # Inputs are divided by the factor while weight columns are multiplied
    # by it, so the product of the two is unchanged.
    input_scale = (1.0 / smoothing_factor).to(torch.float32)
    scaled_weight = weight * smoothing_factor.unsqueeze(0)

    # Symmetric activation step from the smoothed per-column maxima.
    act_max_scaled = activation_stats * input_scale
    act_qmax = (2 ** (activation_bits - 1)) - 1
    activation_scale = (act_max_scaled / act_qmax).clamp_min(epsilon)

    # Symmetric per-row weight quantization.
    weight_qmax = (2 ** (weight_bits - 1)) - 1
    weight_qmin = -(2 ** (weight_bits - 1))
    weight_max = scaled_weight.abs().amax(dim=1).clamp_min(epsilon)
    weight_scales = (weight_max / weight_qmax).unsqueeze(1)
    quantized_weight = (
        torch.round(scaled_weight / weight_scales)
        .clamp(weight_qmin, weight_qmax)
        .to(quant_dtype)
    )

    return SmoothQuantWeights(
        weight=quantized_weight.cpu(),
        weight_scales=weight_scales.to(torch.float32).cpu(),
        input_scale=input_scale.cpu(),
        activation_scale=activation_scale.cpu(),
    )


def summarize_smoothquant(
    stats: Dict[str, SmoothQuantWeights]
) -> Dict[str, Dict[str, float]]:
    """Reduce each layer's scales to scalar means.

    Args:
        stats: Mapping from a layer name to its ``SmoothQuantWeights``.

    Returns:
        Mapping from the same names to a dict with the keys
        ``"weight_scale_mean"`` and ``"activation_scale_mean"``.
    """
    return {
        name: {
            "weight_scale_mean": float(record.weight_scales.mean()),
            "activation_scale_mean": float(record.activation_scale.mean()),
        }
        for name, record in stats.items()
    }