| """ |
| Uniform Post-Training Quantization for DiT models. |
| Implements W8A8 and W4A8 quantization as described in Q&C paper (Eq 1-3). |
| - Channel-wise quantization for weights |
| - Tensor-wise quantization for activations |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class UniformQuantizer:
    """Uniform asymmetric fake-quantizer (quantize, then dequantize).

    The docstring of this file attributes the scheme to Eq 1-3 of the Q&C
    paper: values are mapped onto the integer grid ``[0, 2**bits - 1]`` with a
    scale and zero-point derived from the min/max of the data, then mapped
    back to real values.
    """

    @staticmethod
    def quantize(X, bits, channel_wise=False):
        """Fake-quantize ``X`` to ``bits`` bits and return the dequantized tensor.

        Args:
            X: Tensor to quantize.
            bits: Bit width; the integer grid is ``[0, 2**bits - 1]``.
            channel_wise: If true and ``X`` has at least two dimensions, a
                separate scale/zero-point is computed for every index along
                dim 0 (min/max reduced over all remaining dims). Otherwise a
                single scale/zero-point is shared by the whole tensor.

        Returns:
            A tensor shaped like ``X`` holding the dequantized values.
        """
        if X.numel() == 0:
            # min()/max() are undefined on empty tensors; nothing to quantize.
            return X.clone()

        if channel_wise and X.dim() >= 2:
            reduce_dims = tuple(range(1, X.dim()))
            x_min = X.amin(dim=reduce_dims, keepdim=True)
            x_max = X.amax(dim=reduce_dims, keepdim=True)
        else:
            x_min = X.min()
            x_max = X.max()

        # The zero-point below is clamped into [0, qmax], which is only valid
        # when the quantization range contains 0. Without widening the range,
        # an all-positive (or all-negative) tensor/channel collapses to a
        # single value. Ranges that already straddle 0 are unaffected.
        x_min = torch.clamp(x_min, max=0)
        x_max = torch.clamp(x_max, min=0)

        qmax = 2 ** bits - 1
        scale = (x_max - x_min) / qmax
        scale = torch.clamp(scale, min=1e-8)
        # In low-precision dtypes the 1e-8 floor can round to 0; fall back to
        # a unit scale there so an all-zero range does not divide by zero.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))

        zero_point = torch.clamp(torch.round(-x_min / scale), 0, qmax)
        X_q = torch.clamp(torch.round(X / scale) + zero_point, 0, qmax)
        X_dq = (X_q - zero_point) * scale
        return X_dq
|
|
class QuantLinear(nn.Module):
    """Linear layer with fake-quantized weights and activations.

    Weights are quantized channel-wise (``channel_wise=True``) and the input
    activations tensor-wise (``channel_wise=False``) before a regular
    ``F.linear`` call. The quantized weight is cached between forward passes.
    """

    def __init__(self, original_linear, w_bits=8, a_bits=8):
        """Copy weight/bias from ``original_linear`` and set the bit widths.

        Args:
            original_linear: Source layer; must expose ``in_features``,
                ``out_features``, ``weight`` and ``bias`` (which may be None).
            w_bits: Bit width used for the weight.
            a_bits: Bit width used for the input activations.
        """
        super().__init__()
        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.weight = nn.Parameter(original_linear.weight.data.clone())
        self.bias = nn.Parameter(original_linear.bias.data.clone()) if original_linear.bias is not None else None
        self._quantized_weight = None
        # Callers may set this to True to force re-quantization, e.g. after
        # editing ``weight.data`` directly (which bypasses version tracking).
        self._weight_dirty = True
        # (version, device, dtype) of the weight the cache was built from.
        self._weight_cache_key = None

    def _weight_key(self):
        """Return a fingerprint that changes when the weight is updated or moved."""
        w = self.weight
        # ``_version`` is bumped by in-place updates (optimizer steps,
        # ``load_state_dict``); device/dtype catch ``.to()`` / ``.half()``.
        return (w._version, w.device, w.dtype)

    def _get_quantized_weight(self):
        """Return the cached quantized weight, rebuilding it when stale."""
        key = self._weight_key()
        stale = (
            self._weight_dirty
            or self._quantized_weight is None
            or getattr(self, "_weight_cache_key", None) != key
        )
        if stale:
            # Build the cache outside autograd: a cached non-leaf tensor would
            # keep a graph alive across calls, fail on a second backward and
            # make ``copy.deepcopy`` of the module raise.
            with torch.no_grad():
                self._quantized_weight = UniformQuantizer.quantize(
                    self.weight, self.w_bits, channel_wise=True
                )
            self._weight_cache_key = key
            self._weight_dirty = False
        return self._quantized_weight

    def forward(self, x):
        """Quantize ``x`` tensor-wise and apply the quantized linear map."""
        x_q = UniformQuantizer.quantize(x, self.a_bits, channel_wise=False)
        w_q = self._get_quantized_weight()
        return F.linear(x_q, w_q, self.bias)

    def extra_repr(self):
        """Summarise the layer configuration for ``repr(module)``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, w_bits={self.w_bits}, a_bits={self.a_bits}"
        )
|
|
def quantize_model(model, w_bits=8, a_bits=8, skip_patterns=None):
    """Replace ``nn.Linear`` children of ``model`` with ``QuantLinear`` in place.

    Args:
        model: Module whose descendants are scanned. Only child modules are
            replaced; ``model`` itself is never swapped out.
        w_bits: Weight bit width passed to each ``QuantLinear``.
        a_bits: Activation bit width passed to each ``QuantLinear``.
        skip_patterns: Substring(s) of the dotted module name; a Linear whose
            full name contains any of them is left untouched. May be None,
            a single string, or an iterable of strings.

    Returns:
        The same ``model`` object, modified in place.
    """
    if skip_patterns is None:
        patterns = []
    elif isinstance(skip_patterns, str):
        # A bare string would otherwise be matched character by character.
        patterns = [skip_patterns] if skip_patterns else []
    else:
        # Materialise so one-shot iterables survive being checked per layer.
        patterns = list(skip_patterns)

    # Maps each original Linear to its replacement so a layer that is shared
    # by several parents stays shared (tied) after quantization. Holding the
    # original as a key also keeps it alive, so identities cannot be reused.
    converted = {}

    # Snapshot the traversal first: the loop body rewrites the module tree,
    # and mutating it under a live generator is fragile.
    for name, module in list(model.named_modules()):
        for child_name, child in list(module.named_children()):
            if not isinstance(child, nn.Linear):
                continue
            full_name = f"{name}.{child_name}" if name else child_name
            if any(p in full_name for p in patterns):
                continue
            if child not in converted:
                converted[child] = QuantLinear(child, w_bits, a_bits)
            setattr(module, child_name, converted[child])

    count = len(converted)
    print(f"Quantized {count} Linear layers to W{w_bits}A{a_bits}")
    return model
|
|