# Source: qandc/quantizer.py, uploaded by sanskar753 (commit 1fd542c, verified).
"""
Uniform Post-Training Quantization for DiT models.
Implements W8A8 and W4A8 quantization as described in Q&C paper (Eq 1-3).
- Channel-wise quantization for weights
- Tensor-wise quantization for activations
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UniformQuantizer:
    """Uniform quantizer following Eq 1-3 from the Q&C paper.

    The quantizer is "fake": values are mapped onto an unsigned integer grid
    of ``2 ** bits`` levels and immediately mapped back, so the result has
    the same shape as the input and only differs by the rounding error.
    """

    @staticmethod
    def quantize(X, bits, channel_wise=False):
        """Quantize ``X`` to ``bits`` bits and return the dequantized tensor.

        Args:
            X: Tensor to quantize.
            bits: Number of bits; the integer grid is ``[0, 2 ** bits - 1]``.
            channel_wise: If True and ``X`` has at least two dimensions, a
                separate scale / zero-point is computed for every index of
                dimension 0 (all remaining dimensions are reduced).
                Otherwise a single scale / zero-point covers the whole tensor.

        Returns:
            A new tensor with the shape of ``X`` holding the dequantized
            values. An empty ``X`` is returned as an unchanged copy.
        """
        # min()/max() are undefined on empty tensors; there is nothing to
        # quantize, so hand back a copy instead of raising.
        if X.numel() == 0:
            return X.clone()

        if channel_wise and X.dim() >= 2:
            reduce_dims = tuple(range(1, X.dim()))
            x_min = X.amin(dim=reduce_dims, keepdim=True)
            x_max = X.amax(dim=reduce_dims, keepdim=True)
        else:
            x_min = X.min()
            x_max = X.max()

        # The zero-point below is clamped to [0, qmax], which is only valid
        # when the quantization range contains 0. Without widening the range,
        # an all-positive (or all-negative) tensor/channel saturates: every
        # value is clipped to the top (or bottom) grid level. When the range
        # already spans 0 these two clamps are no-ops.
        x_min = torch.clamp(x_min, max=0)
        x_max = torch.clamp(x_max, min=0)

        qmax = 2 ** bits - 1
        scale = (x_max - x_min) / qmax
        # Floor the scale to avoid division by zero for all-zero ranges.
        # 1e-8 rounds to 0 in float16, so never go below the smallest normal
        # number of the dtype actually in use.
        min_scale = max(1e-8, torch.finfo(scale.dtype).tiny)
        scale = torch.clamp(scale, min=min_scale)

        zero_point = torch.clamp(torch.round(-x_min / scale), 0, qmax)
        X_q = torch.clamp(torch.round(X / scale) + zero_point, 0, qmax)
        X_dq = (X_q - zero_point) * scale
        return X_dq
class QuantLinear(nn.Module):
    """Quantized Linear layer. Weights: channel-wise, Activations: tensor-wise.

    The layer keeps full-precision copies of the wrapped layer's ``weight``
    and ``bias`` (under the same attribute names) and fake-quantizes the
    weight and the incoming activation before calling ``F.linear``. The
    quantized weight is cached between forward passes.
    """

    def __init__(self, original_linear, w_bits=8, a_bits=8):
        """Build a quantized replacement for ``original_linear``.

        Args:
            original_linear: Layer providing ``in_features``,
                ``out_features``, ``weight`` and ``bias`` (``bias`` may be
                None).
            w_bits: Bit width used for the weight.
            a_bits: Bit width used for the input activation.
        """
        super().__init__()
        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.w_bits = w_bits
        self.a_bits = a_bits

        # Copy the tensors so the wrapped layer is left untouched, and keep
        # each parameter's requires_grad flag so frozen layers stay frozen.
        src_weight = original_linear.weight
        self.weight = nn.Parameter(
            src_weight.detach().clone(),
            requires_grad=src_weight.requires_grad,
        )
        src_bias = original_linear.bias
        if src_bias is not None:
            self.bias = nn.Parameter(
                src_bias.detach().clone(),
                requires_grad=src_bias.requires_grad,
            )
        else:
            self.bias = None

        # Cache of the fake-quantized weight and the key of the weight state
        # it was computed from. Setting _weight_dirty = True forces a refresh.
        self._quantized_weight = None
        self._weight_key = None
        self._weight_dirty = True

    def _weight_cache_key(self):
        """Return a cheap fingerprint of the current state of ``self.weight``.

        The fingerprint changes when the parameter object or its storage is
        replaced, when it is modified in place (version counter), or when its
        device or dtype changes. In-place writes through ``weight.data`` do
        not bump the version counter; set ``_weight_dirty = True`` after
        such a write.
        """
        weight = self.weight
        # Inference tensors do not track a version counter.
        version = None if weight.is_inference() else weight._version
        return (id(weight), weight.data_ptr(), version, weight.device, weight.dtype)

    def _get_quantized_weight(self):
        """Return the channel-wise quantized weight, refreshing a stale cache."""
        key = self._weight_cache_key()
        stale = (
            self._weight_dirty
            or self._quantized_weight is None
            # getattr: tolerate instances created before _weight_key existed.
            or key != getattr(self, "_weight_key", None)
        )
        if stale:
            self._quantized_weight = UniformQuantizer.quantize(
                self.weight, self.w_bits, channel_wise=True
            )
            self._weight_key = key
            self._weight_dirty = False
        return self._quantized_weight

    def forward(self, x):
        """Apply the linear map using the quantized activation and weight."""
        x_q = UniformQuantizer.quantize(x, self.a_bits, channel_wise=False)
        w_q = self._get_quantized_weight()
        return F.linear(x_q, w_q, self.bias)
def quantize_model(model, w_bits=8, a_bits=8, skip_patterns=None):
    """Replace the ``nn.Linear`` children inside ``model`` with ``QuantLinear``.

    The model is modified in place. Only layers that are children of some
    module are replaced; ``model`` itself is never swapped out.

    Args:
        model: Root module to traverse.
        w_bits: Weight bit width passed to every ``QuantLinear``.
        a_bits: Activation bit width passed to every ``QuantLinear``.
        skip_patterns: Substrings; a Linear layer whose dotted name contains
            any of them is left untouched. May be None, a single string, or
            any iterable of strings.

    Returns:
        The same ``model`` object, after replacement. A one-line summary is
        printed as a side effect.
    """
    if skip_patterns is None:
        skip_patterns = ()
    elif isinstance(skip_patterns, str):
        # A bare string would otherwise be matched character by character.
        skip_patterns = (skip_patterns,)
    else:
        # Materialise once so one-shot iterables survive being re-scanned
        # for every layer.
        skip_patterns = tuple(skip_patterns)

    # Collect the targets first so the module tree is not mutated while it
    # is being traversed.
    targets = []
    for name, module in model.named_modules():
        for child_name, child in module.named_children():
            if not isinstance(child, nn.Linear):
                continue
            full_name = f"{name}.{child_name}" if name else child_name
            if any(p in full_name for p in skip_patterns):
                continue
            targets.append((module, child_name, child))

    # A Linear registered under several parents must map to a single
    # QuantLinear so the sharing is preserved. `targets` keeps every child
    # alive, so id() is a stable key here.
    wrapped = {}
    for parent, child_name, child in targets:
        replacement = wrapped.get(id(child))
        if replacement is None:
            replacement = QuantLinear(child, w_bits, a_bits)
            wrapped[id(child)] = replacement
        setattr(parent, child_name, replacement)

    count = len(targets)
    print(f"Quantized {count} Linear layers to W{w_bits}A{a_bits}")
    return model