from __future__ import annotations

import math
from typing import Optional

import torch
from torch import Tensor, nn


def _select_quant_dtype(bits: int) -> torch.dtype:
    if bits <= 0:
        raise ValueError("Quantization bits must be positive.")
    if bits <= 8:
        return torch.int8
    if bits <= 16:
        return torch.int16
    raise ValueError("Quantization bits above 16 are not supported.")


class QuantizedLinear(nn.Module):
    """Weight-only linear layer with per-group scales."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        weight_bits: int = 4,
        group_size: int = 128,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.qmin = -(2 ** (weight_bits - 1))
        self.qmax = (2 ** (weight_bits - 1)) - 1
        self.num_groups = math.ceil(in_features / group_size)
        self.quant_dtype = _select_quant_dtype(weight_bits)

        weight_shape = (out_features, in_features)
        scale_shape = (out_features, self.num_groups)
        self.register_buffer("weight", torch.zeros(weight_shape, dtype=self.quant_dtype))
        self.register_buffer(
            "weight_scales", torch.ones(scale_shape, dtype=torch.float32)
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self._weight_cache: Optional[Tensor] = None

    def _invalidate_cache(self) -> None:
        self._weight_cache = None

    def refresh_weight_cache(self) -> None:
        self._weight_cache = self._dequantize_weight()

    def _dequantize_weight(self) -> Tensor:
        # Dequantize group by group: each group of input channels shares one
        # scale per output row.
        group_tensors = []
        for group_idx in range(self.num_groups):
            start = group_idx * self.group_size
            end = min((group_idx + 1) * self.group_size, self.in_features)
            block = self.weight[:, start:end].float()
            scale = self.weight_scales[:, group_idx].unsqueeze(1)
            group_tensors.append(block * scale)
        return torch.cat(group_tensors, dim=1)

    def forward(self, input: Tensor) -> Tensor:
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        weight = self._weight_cache
        if weight.dtype != input.dtype:
            weight = weight.to(input.dtype)
        bias = self.bias
        if bias is not None and bias.device != input.device:
            bias = bias.to(input.device)
        if bias is not None and bias.dtype != input.dtype:
            bias = bias.to(input.dtype)
        return nn.functional.linear(input, weight, bias)

    def load_quant_state(self, weight: Tensor, weight_scales: Tensor) -> None:
        if weight.shape != self.weight.shape:
            raise ValueError(
                f"Quantized weight shape mismatch: expected {tuple(self.weight.shape)}, "
                f"got {tuple(weight.shape)}"
            )
        if weight_scales.shape != self.weight_scales.shape:
            raise ValueError(
                f"Scale tensor shape mismatch: expected {tuple(self.weight_scales.shape)}, "
                f"got {tuple(weight_scales.shape)}"
            )
        # Clamp to the representable range for weight_bits: the storage dtype
        # (e.g. int8 for 4-bit weights) is wider than the quantization grid.
        self.weight.copy_(weight.to(dtype=self.quant_dtype).clamp(self.qmin, self.qmax))
        self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
        self._invalidate_cache()

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"group_size={self.group_size}, bits={self.weight_bits}, bias={self.bias is not None}"
        )
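
# Illustrative helper, not part of the original module: one way to produce a
# (weight, weight_scales) pair that QuantizedLinear.load_quant_state() accepts,
# using plain symmetric absmax rounding per group. The function name and the
# absmax scheme are assumptions for this sketch; real pipelines (GPTQ, AWQ, ...)
# use more sophisticated rounding. Usage:
#   layer.load_quant_state(*quantize_weight_per_group(w, weight_bits=4))
def quantize_weight_per_group(
    weight_fp: Tensor, weight_bits: int = 4, group_size: int = 128
) -> tuple[Tensor, Tensor]:
    out_features, in_features = weight_fp.shape
    qmin = -(2 ** (weight_bits - 1))
    qmax = (2 ** (weight_bits - 1)) - 1
    num_groups = math.ceil(in_features / group_size)
    qweight = torch.empty_like(weight_fp, dtype=_select_quant_dtype(weight_bits))
    scales = torch.empty(out_features, num_groups, dtype=torch.float32)
    for g in range(num_groups):
        start, end = g * group_size, min((g + 1) * group_size, in_features)
        block = weight_fp[:, start:end].float()
        # Symmetric scale: map the per-row absmax of this group onto qmax.
        scale = block.abs().amax(dim=1).clamp(min=1e-8) / qmax
        scales[:, g] = scale
        qweight[:, start:end] = (
            torch.round(block / scale.unsqueeze(1)).clamp(qmin, qmax).to(qweight.dtype)
        )
    return qweight, scales
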
class SmoothQuantLinear(nn.Module):
    """Linear layer with SmoothQuant W8A8 (or configurable) quantization."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        weight_bits: int = 8,
        activation_bits: int = 8,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if weight_bits <= 0 or weight_bits > 16:
            raise ValueError("Weight bits must be in range [1, 16].")
        if activation_bits <= 0 or activation_bits > 16:
            raise ValueError("Activation bits must be in range [1, 16].")
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits
        self.weight_qmin = -(2 ** (weight_bits - 1))
        self.weight_qmax = (2 ** (weight_bits - 1)) - 1
        self.activation_qmin = -(2 ** (activation_bits - 1))
        self.activation_qmax = (2 ** (activation_bits - 1)) - 1
        self.quant_dtype = _select_quant_dtype(weight_bits)

        weight_shape = (out_features, in_features)
        self.register_buffer("weight", torch.zeros(weight_shape, dtype=self.quant_dtype))
        self.register_buffer(
            "weight_scales", torch.ones(out_features, 1, dtype=torch.float32)
        )
        self.register_buffer(
            "input_scale", torch.ones(in_features, dtype=torch.float32)
        )
        self.register_buffer(
            "activation_scale", torch.ones(in_features, dtype=torch.float32)
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self._weight_cache: Optional[Tensor] = None

    def _invalidate_cache(self) -> None:
        self._weight_cache = None

    def refresh_weight_cache(self) -> None:
        # Per-output-channel dequantization: weight_scales has shape
        # (out_features, 1) and broadcasts over the input dimension.
        self._weight_cache = self.weight.float() * self.weight_scales

    def forward(self, input: Tensor) -> Tensor:
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        activation_scale = self.activation_scale.to(input.device)
        input_scale = self.input_scale.to(input.device)
        # SmoothQuant smoothing followed by fake-quantization of activations.
        scaled_input = input * input_scale
        quantized = torch.round(scaled_input / activation_scale).clamp(
            self.activation_qmin, self.activation_qmax
        )
        dequant_input = quantized * activation_scale
        weight = self._weight_cache
        if weight.dtype != dequant_input.dtype:
            weight = weight.to(dequant_input.dtype)
        bias = self.bias
        if bias is not None and bias.device != input.device:
            bias = bias.to(input.device)
        if bias is not None and bias.dtype != dequant_input.dtype:
            bias = bias.to(dequant_input.dtype)
        return nn.functional.linear(dequant_input, weight, bias)

    def load_quant_state(
        self,
        weight: Tensor,
        weight_scales: Tensor,
        input_scale: Tensor,
        activation_scale: Tensor,
    ) -> None:
        if weight.shape != self.weight.shape:
            raise ValueError(
                f"Quantized weight shape mismatch: expected {tuple(self.weight.shape)}, "
                f"got {tuple(weight.shape)}"
            )
        if weight_scales.shape != self.weight_scales.shape:
            raise ValueError(
                f"Weight scale shape mismatch: expected {tuple(self.weight_scales.shape)}, "
                f"got {tuple(weight_scales.shape)}"
            )
        if input_scale.shape != self.input_scale.shape:
            raise ValueError(
                f"Input scale shape mismatch: expected {tuple(self.input_scale.shape)}, "
                f"got {tuple(input_scale.shape)}"
            )
        if activation_scale.shape != self.activation_scale.shape:
            raise ValueError(
                f"Activation scale shape mismatch: expected {tuple(self.activation_scale.shape)}, "
                f"got {tuple(activation_scale.shape)}"
            )
        # Clamp to the representable range for weight_bits before storing.
        self.weight.copy_(
            weight.to(dtype=self.quant_dtype).clamp(self.weight_qmin, self.weight_qmax)
        )
        self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
        self.input_scale.copy_(input_scale.to(dtype=torch.float32))
        self.activation_scale.copy_(activation_scale.to(dtype=torch.float32))
        self._invalidate_cache()

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"weight_bits={self.weight_bits}, activation_bits={self.activation_bits}, "
            f"bias={self.bias is not None}"
        )
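

# Illustrative calibration sketch, not part of the original module: one way to
# build the four tensors SmoothQuantLinear.load_quant_state() expects, using
# the SmoothQuant smoothing factor s_j = max|X_j|^alpha / max|W_j|^(1-alpha).
# Since forward() multiplies the input by input_scale, we set input_scale =
# 1 / s and fold s into the weight columns before quantizing. The function
# name, alpha default, and absmax statistics are assumptions for this sketch.
def smoothquant_calibrate(
    weight_fp: Tensor, sample_acts: Tensor, *, alpha: float = 0.5, bits: int = 8
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    qmin = -(2 ** (bits - 1))
    qmax = (2 ** (bits - 1)) - 1
    act_absmax = sample_acts.abs().amax(dim=0).clamp(min=1e-8)    # (in,)
    w_absmax = weight_fp.abs().amax(dim=0).clamp(min=1e-8)        # (in,)
    smooth = (act_absmax ** alpha) / (w_absmax ** (1.0 - alpha))  # s_j
    input_scale = 1.0 / smooth
    # Fold the smoothing factor into the weight columns, then quantize the
    # result symmetrically per output channel.
    w_smoothed = weight_fp * smooth                                # (out, in)
    weight_scales = (w_smoothed.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-8)
    qweight = (
        torch.round(w_smoothed / weight_scales)
        .clamp(qmin, qmax)
        .to(_select_quant_dtype(bits))
    )
    # Per-channel scale for the smoothed activations x_j / s_j.
    activation_scale = (act_absmax * input_scale).clamp(min=1e-8) / qmax
    return qweight, weight_scales, input_scale, activation_scale


if __name__ == "__main__":
    # Smoke test for both layers against a dense reference; shapes and data
    # are arbitrary, and both layers keep their default zero bias.
    torch.manual_seed(0)
    w = torch.randn(64, 256)
    x = torch.randn(8, 256)
    reference = x @ w.t()

    wol = QuantizedLinear(256, 64, weight_bits=4, group_size=128)
    wol.load_quant_state(*quantize_weight_per_group(w, weight_bits=4, group_size=128))
    print("weight-only max error:", (wol(x) - reference).abs().max().item())

    sq = SmoothQuantLinear(256, 64)
    sq.load_quant_state(*smoothquant_calibrate(w, x))
    print("smoothquant max error:", (sq(x) - reference).abs().max().item())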