| from __future__ import annotations | |
| import math | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor, nn | |
| def _select_quant_dtype(bits: int) -> torch.dtype: | |
| if bits <= 0: | |
| raise ValueError("Quantization bits must be positive.") | |
| if bits <= 8: | |
| return torch.int8 | |
| if bits <= 16: | |
| return torch.int16 | |
| raise ValueError("Quantization bits above 16 are not supported.") | |
class QuantizedLinear(nn.Module):
    """Weight-only linear layer with per-group dequantization scales.

    The integer weight matrix is kept as a buffer of ``quant_dtype`` values,
    with one float32 scale per (output row, input-feature group).  The float
    weight is reconstructed lazily and cached; the cache is cleared whenever
    new quantized state is loaded.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        weight_bits: int = 4,
        group_size: int = 128,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if group_size <= 0:
            raise ValueError("group_size must be positive.")
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.group_size = group_size
        # Representable signed integer range for the configured bit width.
        self.qmin = -(2 ** (weight_bits - 1))
        self.qmax = (2 ** (weight_bits - 1)) - 1
        # ceil: the final group may cover fewer than group_size columns.
        self.num_groups = math.ceil(in_features / group_size)
        self.quant_dtype = _select_quant_dtype(weight_bits)
        self.register_buffer(
            "weight", torch.zeros(out_features, in_features, dtype=self.quant_dtype)
        )
        self.register_buffer(
            "weight_scales",
            torch.ones(out_features, self.num_groups, dtype=torch.float32),
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Lazily populated dequantized copy of ``weight``; None means stale.
        self._weight_cache: Optional[Tensor] = None

    def _invalidate_cache(self) -> None:
        # Drop the cache so the next forward() re-dequantizes from buffers.
        self._weight_cache = None

    def refresh_weight_cache(self) -> None:
        """Recompute the cached float weight from the quantized buffers."""
        self._weight_cache = self._dequantize_weight()

    def _dequantize_weight(self) -> Tensor:
        """Return the float32 weight with one scale applied per input group.

        Vectorized: each group scale is broadcast across its columns instead
        of slicing and concatenating group-by-group in a Python loop.
        """
        expanded = self.weight_scales.repeat_interleave(self.group_size, dim=1)
        # Trim the padding columns introduced by a partial final group.
        expanded = expanded[:, : self.in_features]
        return self.weight.float() * expanded

    def forward(self, input: Tensor) -> Tensor:
        """Compute ``input @ dequantized_weight.T + bias``.

        The dequantized weight is cached and moved/cast to the input's device
        and dtype; parameters and buffers themselves are never mutated.
        """
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        weight = self._weight_cache
        if weight.dtype != input.dtype:
            weight = weight.to(input.dtype)
        bias = self.bias
        if bias is not None:
            # .to() is a no-op when device and dtype already match.
            bias = bias.to(device=input.device, dtype=input.dtype)
        return nn.functional.linear(input, weight, bias)

    def load_quant_state(self, weight: Tensor, weight_scales: Tensor) -> None:
        """Copy externally quantized state into this layer's buffers.

        Raises:
            ValueError: on a shape mismatch, or when ``weight`` holds values
                outside the representable ``[qmin, qmax]`` range (such values
                would silently wrap when cast to ``quant_dtype``).
        """
        if weight.shape != self.weight.shape:
            raise ValueError(
                f"Quantized weight shape mismatch: expected {tuple(self.weight.shape)}, "
                f"got {tuple(weight.shape)}"
            )
        if weight_scales.shape != self.weight_scales.shape:
            raise ValueError(
                f"Scale tensor shape mismatch: expected {tuple(self.weight_scales.shape)}, "
                f"got {tuple(weight_scales.shape)}"
            )
        if weight.numel() and (weight.min() < self.qmin or weight.max() > self.qmax):
            raise ValueError(
                f"Quantized weight values must lie in [{self.qmin}, {self.qmax}]."
            )
        self.weight.copy_(weight.to(dtype=self.quant_dtype))
        self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
        self._invalidate_cache()

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"group_size={self.group_size}, bits={self.weight_bits}, bias={self.bias is not None}"
        )
class SmoothQuantLinear(nn.Module):
    """Linear layer with SmoothQuant-style W8A8 (or configurable) fake quantization.

    Weights are stored as integers with one float32 scale per output row.
    Activations are first multiplied by a per-channel ``input_scale`` (the
    smoothing factor -- presumably the loaded weights are pre-scaled to
    compensate; confirm against the calibration code), then fake-quantized
    with ``activation_scale`` (round, clamp, rescale) before the matmul.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        weight_bits: int = 8,
        activation_bits: int = 8,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if weight_bits <= 0 or weight_bits > 16:
            raise ValueError("Weight bits must be in range [1, 16].")
        if activation_bits <= 0 or activation_bits > 16:
            raise ValueError("Activation bits must be in range [1, 16].")
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits
        # Representable signed integer ranges for weights and activations.
        self.weight_qmin = -(2 ** (weight_bits - 1))
        self.weight_qmax = (2 ** (weight_bits - 1)) - 1
        self.activation_qmin = -(2 ** (activation_bits - 1))
        self.activation_qmax = (2 ** (activation_bits - 1)) - 1
        self.quant_dtype = _select_quant_dtype(weight_bits)
        self.register_buffer(
            "weight", torch.zeros(out_features, in_features, dtype=self.quant_dtype)
        )
        self.register_buffer(
            "weight_scales", torch.ones(out_features, 1, dtype=torch.float32)
        )
        self.register_buffer(
            "input_scale", torch.ones(in_features, dtype=torch.float32)
        )
        self.register_buffer(
            "activation_scale", torch.ones(in_features, dtype=torch.float32)
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Lazily populated dequantized copy of ``weight``; None means stale.
        self._weight_cache: Optional[Tensor] = None

    def _invalidate_cache(self) -> None:
        # Drop the cache so the next forward() re-dequantizes from buffers.
        self._weight_cache = None

    def refresh_weight_cache(self) -> None:
        """Recompute the cached float weight (per-output-row scale broadcast)."""
        self._weight_cache = self.weight.float() * self.weight_scales

    def forward(self, input: Tensor) -> Tensor:
        """Fake-quantize the activations, then apply the dequantized linear.

        Scales are cast to the input's device *and* dtype so that a reduced
        precision input is not silently upcast to float32 -- keeps the output
        dtype consistent with ``QuantizedLinear.forward``.
        """
        if self._weight_cache is None or self._weight_cache.device != input.device:
            self.refresh_weight_cache()
            self._weight_cache = self._weight_cache.to(input.device)
        activation_scale = self.activation_scale.to(
            device=input.device, dtype=input.dtype
        )
        input_scale = self.input_scale.to(device=input.device, dtype=input.dtype)
        # Smooth, quantize (round + clamp), then dequantize the activations.
        scaled_input = input * input_scale
        quantized = torch.round(scaled_input / activation_scale).clamp(
            self.activation_qmin, self.activation_qmax
        )
        dequant_input = quantized * activation_scale
        weight = self._weight_cache
        if weight.dtype != dequant_input.dtype:
            weight = weight.to(dequant_input.dtype)
        bias = self.bias
        if bias is not None:
            # .to() is a no-op when device and dtype already match.
            bias = bias.to(device=input.device, dtype=dequant_input.dtype)
        return nn.functional.linear(dequant_input, weight, bias)

    def _expect_shape(self, label: str, value: Tensor, reference: Tensor) -> None:
        # Shared shape validation for load_quant_state; raises a uniform error.
        if value.shape != reference.shape:
            raise ValueError(
                f"{label} shape mismatch: expected {tuple(reference.shape)}, "
                f"got {tuple(value.shape)}"
            )

    def load_quant_state(
        self,
        weight: Tensor,
        weight_scales: Tensor,
        input_scale: Tensor,
        activation_scale: Tensor,
    ) -> None:
        """Copy externally quantized state into this layer's buffers.

        Raises:
            ValueError: when any tensor's shape does not match its buffer.
        """
        self._expect_shape("Quantized weight", weight, self.weight)
        self._expect_shape("Weight scale", weight_scales, self.weight_scales)
        self._expect_shape("Input scale", input_scale, self.input_scale)
        self._expect_shape("Activation scale", activation_scale, self.activation_scale)
        self.weight.copy_(weight.to(dtype=self.quant_dtype))
        self.weight_scales.copy_(weight_scales.to(dtype=torch.float32))
        self.input_scale.copy_(input_scale.to(dtype=torch.float32))
        self.activation_scale.copy_(activation_scale.to(dtype=torch.float32))
        self._invalidate_cache()

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"weight_bits={self.weight_bits}, activation_bits={self.activation_bits}, "
            f"bias={self.bias is not None}"
        )