File size: 2,788 Bytes
1fd542c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Uniform Post-Training Quantization for DiT models.
Implements W8A8 and W4A8 quantization as described in the Q&C paper (Eq. 1-3).
- Channel-wise quantization for weights
- Tensor-wise quantization for activations
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class UniformQuantizer:
    """Simulated ("fake") uniform asymmetric quantizer.

    Values are mapped onto the integer grid ``[0, 2**bits - 1]`` and
    immediately mapped back, so the result has the same shape as the input
    but only takes values representable at the requested bit-width.
    """

    @staticmethod
    def quantize(X, bits, channel_wise=False):
        """Quantize and de-quantize ``X``.

        Args:
            X: Tensor to quantize.
            bits: Bit-width of the integer grid; must be at least 1.
            channel_wise: When true and ``X`` has two or more dimensions,
                a separate scale / zero-point is computed for every index
                of dim 0 (min/max reduced over all remaining dims).
                Otherwise a single scale / zero-point covers the tensor.

        Returns:
            The de-quantized tensor ``(X_q - zero_point) * scale``.

        Raises:
            ValueError: If ``bits`` is smaller than 1.
        """
        if bits < 1:
            raise ValueError(f"bits must be >= 1, got {bits}")
        if X.numel() == 0:
            # min()/max() are undefined on empty tensors; nothing to quantize.
            return X.clone()

        if channel_wise and X.dim() >= 2:
            reduce_dims = tuple(range(1, X.dim()))
            x_min = X.amin(dim=reduce_dims, keepdim=True)
            x_max = X.amax(dim=reduce_dims, keepdim=True)
        else:
            x_min = X.min()
            x_max = X.max()

        # The zero-point below is restricted to [0, qmax], which is only
        # correct when the quantization range contains 0. Without widening
        # the range, an all-positive (or all-negative) tensor/channel has its
        # zero-point clipped and every value saturates at one end of the grid.
        x_min = torch.clamp(x_min, max=0)
        x_max = torch.clamp(x_max, min=0)

        qmax = 2 ** bits - 1

        # Floor for the scale so a zero-width range does not divide by zero.
        # 1e-8 rounds to 0 in float16, so never go below the dtype's smallest
        # normal number.
        eps = 1e-8
        if X.is_floating_point():
            eps = max(eps, torch.finfo(X.dtype).tiny)

        scale = (x_max - x_min) / qmax
        scale = torch.clamp(scale, min=eps)
        zero_point = torch.clamp(torch.round(-x_min / scale), 0, qmax)
        X_q = torch.clamp(torch.round(X / scale) + zero_point, 0, qmax)
        X_dq = (X_q - zero_point) * scale
        return X_dq

class QuantLinear(nn.Module):
    """Linear layer with simulated quantization of weights and activations.

    Weights are quantized channel-wise (one scale per row of ``weight``) and
    the input activation is quantized tensor-wise on every forward pass. The
    layer keeps a full-precision copy of the wrapped layer's weight and bias
    and applies ``UniformQuantizer.quantize`` on the fly.

    Args:
        original_linear: Layer whose ``in_features``, ``out_features``,
            ``weight`` and ``bias`` are copied.
        w_bits: Bit-width used for the weight.
        a_bits: Bit-width used for the input activation.
    """

    def __init__(self, original_linear, w_bits=8, a_bits=8):
        super().__init__()
        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.weight = nn.Parameter(original_linear.weight.detach().clone())
        if original_linear.bias is not None:
            self.bias = nn.Parameter(original_linear.bias.detach().clone())
        else:
            self.register_parameter("bias", None)
        # Cached quantized weight plus the fingerprint of the weight it was
        # computed from. Setting ``_weight_dirty = True`` forces a refresh.
        self._quantized_weight = None
        self._weight_dirty = True
        self._weight_cache_key = None

    def _current_weight_key(self):
        """Return a fingerprint that changes whenever the weight may have changed."""
        w = self.weight
        # Inference tensors do not track a version counter.
        version = 0 if w.is_inference() else w._version
        # The version covers in-place writes; data_ptr/device/dtype cover the
        # storage being swapped out (e.g. by ``.to()`` / ``.half()``).
        return (version, w.data_ptr(), w.device, w.dtype, tuple(w.shape), self.w_bits)

    def _get_quantized_weight(self):
        """Return the quantized weight, reusing the cache only when it is valid."""
        if torch.is_grad_enabled() and self.weight.requires_grad:
            # A cached tensor would carry an autograd graph that gets reused
            # by later forward passes; recompute so each pass has its own.
            return UniformQuantizer.quantize(self.weight, self.w_bits, channel_wise=True)

        key = self._current_weight_key()
        stale = (
            self._weight_dirty
            or self._quantized_weight is None
            or key != self._weight_cache_key
        )
        if stale:
            self._quantized_weight = UniformQuantizer.quantize(
                self.weight, self.w_bits, channel_wise=True
            )
            self._weight_cache_key = key
            self._weight_dirty = False
        return self._quantized_weight

    def forward(self, x):
        """Apply the linear map to the quantized input using the quantized weight."""
        x_q = UniformQuantizer.quantize(x, self.a_bits, channel_wise=False)
        w_q = self._get_quantized_weight()
        return F.linear(x_q, w_q, self.bias)

    def extra_repr(self):
        """Summarize the layer's shape and bit-widths for ``repr()``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, w_bits={self.w_bits}, a_bits={self.a_bits}"
        )

def quantize_model(model, w_bits=8, a_bits=8, skip_patterns=None):
    """Replace the ``nn.Linear`` sub-modules of ``model`` with ``QuantLinear``.

    The model is modified in place. A Linear layer is left untouched when any
    entry of ``skip_patterns`` occurs as a substring of its dotted name. If
    ``model`` itself is an ``nn.Linear`` it is not replaced, because only
    child slots are rewritten.

    Args:
        model: Root module to convert.
        w_bits: Weight bit-width passed to every ``QuantLinear``.
        a_bits: Activation bit-width passed to every ``QuantLinear``.
        skip_patterns: Substrings of dotted module names to skip. A single
            string is treated as one pattern. ``None`` skips nothing.

    Returns:
        The same ``model`` object, after modification.
    """
    if skip_patterns is None:
        skip_patterns = []
    elif isinstance(skip_patterns, str):
        # Iterating a bare string would match on individual characters.
        skip_patterns = [skip_patterns]
    else:
        skip_patterns = list(skip_patterns)

    # One wrapper per original layer, so a Linear that is referenced from
    # several places stays a single shared layer after conversion. Keying on
    # the module also keeps the original alive while the dict is in use.
    wrappers = {}
    count = 0

    # Snapshot both traversals: the tree is rewritten while we walk it.
    for name, module in list(model.named_modules()):
        # Walk the raw child slots; named_children() hides a child that is
        # registered under a second name, which would leave that slot
        # unconverted.
        for child_name, child in list(module._modules.items()):
            if not isinstance(child, nn.Linear):
                continue
            full_name = f"{name}.{child_name}" if name else child_name
            if any(p in full_name for p in skip_patterns):
                continue
            if child not in wrappers:
                wrappers[child] = QuantLinear(child, w_bits, a_bits)
            setattr(module, child_name, wrappers[child])
            count += 1

    print(f"Quantized {count} Linear layers to W{w_bits}A{a_bits}")
    return model