File size: 4,100 Bytes
8d8ae43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules import sparse as sp


def _pack_4bit_tensor(q: torch.Tensor) -> torch.Tensor:
    """Pack signed 4-bit values into bytes, two values per ``uint8``.

    Every value is saturated to ``[-8, 7]``, biased by +8 into ``[0, 15]`` and
    stored as one nibble. Element ``2*i`` of the flattened input lands in the
    low nibble of output byte ``i`` and element ``2*i + 1`` in the high nibble.
    An odd number of elements is padded with one zero nibble.

    Args:
        q: Tensor of any shape holding the values to pack.

    Returns:
        A 1-D ``uint8`` tensor with ``ceil(q.numel() / 2)`` elements on the
        same device as ``q``.
    """
    if q.dtype == torch.int8:
        q = torch.clamp(q, -8, 7)
    else:
        # Saturate *before* narrowing to int8: casting first wraps out-of-range
        # values (e.g. 200 -> -56), which then clamp to the wrong end.
        # Integer/bool inputs are widened so the bounds -8/7 are representable
        # even for unsigned dtypes.
        wide = q if q.is_floating_point() else q.to(torch.int64)
        q = torch.clamp(wide, -8, 7).to(torch.int8)
    q_u = ((q + 8) & 0xF).to(torch.uint8).reshape(-1)
    if q_u.numel() % 2 == 1:
        # Pad to an even count so the values can be paired up.
        q_u = torch.cat([q_u, torch.zeros(1, dtype=torch.uint8, device=q_u.device)])
    q_u = q_u.view(-1, 2)
    return q_u[:, 0] | (q_u[:, 1] << 4)


def _unpack_4bit_tensor(packed: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Decode bytes that each hold two biased 4-bit values into signed ``int8``.

    The low nibble of every byte is emitted before its high nibble. The
    decoded sequence is truncated to the number of elements implied by
    ``shape`` (dropping any padding nibble), reshaped, and shifted by -8.

    Args:
        packed: Tensor of packed bytes; it is flattened before decoding.
        shape: Shape of the tensor to reconstruct.

    Returns:
        An ``int8`` tensor of shape ``shape`` on the same device as ``packed``.
    """
    flat_bytes = packed.reshape(-1)
    # Column 0 holds the low nibbles, column 1 the high nibbles; flattening the
    # (n, 2) stack row by row interleaves them as low0, high0, low1, high1, ...
    nibble_pairs = torch.stack((flat_bytes & 0xF, flat_bytes >> 4), dim=1)
    values = nibble_pairs.to(torch.int8).reshape(-1)
    n_values = int(torch.tensor(shape, dtype=torch.int64).prod())
    return values[:n_values].view(shape) - 8


def _quantize_weight(weight: torch.Tensor, bits: int = 4) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a weight tensor to signed 4-bit values with one scale per row.

    Each slice along dimension 0 is scaled by ``max(|row|) / 7``, rounded,
    saturated to ``[-8, 7]`` and handed to ``_pack_4bit_tensor``.

    Args:
        weight: Tensor with at least one dimension; dimension 0 indexes the
            rows that each receive their own scale.
        bits: Bit width of the quantized values. Only ``4`` is accepted.

    Returns:
        A ``(packed, scales)`` pair: the packed values as produced by
        ``_pack_4bit_tensor`` and a tensor of ``weight.size(0)`` scales in
        ``weight.dtype``.

    Raises:
        ValueError: If ``bits`` is not 4.
    """
    if bits != 4:
        raise ValueError('Only 4-bit quantization is supported.')

    # reshape (not view) so non-contiguous weights, e.g. transposes, work too.
    flat_weight = weight.reshape(weight.size(0), -1)
    max_vals = flat_weight.abs().amax(dim=1)
    scales = (max_vals / 7.0).to(weight.dtype)
    # Guard after the division: an all-zero row, or a tiny maximum that
    # underflows to 0 in a low-precision dtype, would otherwise divide by zero.
    scales = torch.where(scales > 0, scales, torch.ones_like(scales))

    q_weight = torch.clamp((flat_weight / scales[:, None]).round(), -8, 7).to(torch.int8)
    q_weight = q_weight.reshape(weight.shape)
    packed = _pack_4bit_tensor(q_weight)
    return packed, scales


class QuantizedLinear(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit integers.

    The weight lives in two buffers: ``weight_packed`` (``uint8``, two 4-bit
    values per byte) and ``weight_scales`` (one scale per output feature).
    It is dequantized on every forward pass.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer has a bias parameter.
        dtype: Floating dtype used for the scales and the bias at construction.
        bits: Bit width of the stored weight. Only ``4`` is accepted.

    Raises:
        ValueError: If ``bits`` is not 4.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = torch.float16,
        bits: int = 4,
    ):
        super().__init__()
        # Fail fast: packing and unpacking in this module are 4-bit only.
        if bits != 4:
            raise ValueError('Only 4-bit quantization is supported.')
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.dtype = dtype

        # Allocate the packed buffer at its final size (two values per byte,
        # rounded up) so that load_state_dict() into a freshly constructed
        # module does not fail with a size mismatch. Contents are
        # uninitialized until from_float() or load_state_dict() fills them.
        packed_numel = (in_features * out_features + 1) // 2
        self.register_buffer('weight_packed', torch.empty(packed_numel, dtype=torch.uint8))
        self.register_buffer('weight_scales', torch.empty((out_features,), dtype=self.dtype))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=self.dtype))
        else:
            self.bias = None

    @classmethod
    def from_float(cls, module: nn.Linear, bits: int = 4, dtype: torch.dtype = torch.float16):
        """Build a quantized layer from an existing linear module.

        Args:
            module: Source module; its ``in_features``, ``out_features``,
                ``weight`` and ``bias`` are read.
            bits: Bit width passed to the constructor and the quantizer.
            dtype: Dtype for the stored scales and bias.

        Returns:
            A new instance holding the quantized weight and a copy of the
            bias cast to ``dtype``.
        """
        quantized = cls(module.in_features, module.out_features, module.bias is not None, dtype=dtype, bits=bits)
        packed, scales = _quantize_weight(module.weight.data, bits=bits)
        quantized.weight_packed = packed.to(dtype=torch.uint8)
        quantized.weight_scales = scales.to(dtype=dtype)
        if module.bias is not None:
            quantized.bias.data = module.bias.data.to(dtype=dtype)
        return quantized

    def dequantize_weight(self) -> torch.Tensor:
        """Return the ``(out_features, in_features)`` floating-point weight."""
        q = _unpack_4bit_tensor(self.weight_packed, torch.Size((self.out_features, self.in_features)))
        scales = self.weight_scales.view(self.out_features, 1)
        # Follow the live dtype of the scales rather than ``self.dtype``:
        # ``module.to(dtype)`` / ``.float()`` converts the buffers and the bias
        # but cannot update a plain attribute.
        return q.to(scales.dtype) * scales

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear map; the result has the dequantized weight's dtype."""
        weight = self.dequantize_weight()
        return F.linear(input.to(weight.dtype), weight, self.bias)


class QuantizedSparseLinear(QuantizedLinear):
    """4-bit quantized linear layer applied to the features of a container.

    ``forward`` reads ``input.feats`` and returns ``input.replace(new_feats)``.
    Presumably ``input`` is the project's sparse tensor type — confirm against
    the ``sp`` module.
    """

    def forward(self, input):
        """Apply the dequantized linear map to ``input.feats``."""
        weight = self.dequantize_weight()
        # Cast to the dtype of the dequantized weight rather than ``self.dtype``:
        # after ``module.to(dtype)`` / ``.float()`` the buffers and bias change
        # dtype while the ``dtype`` attribute does not, which would otherwise
        # hand F.linear mismatched operands.
        feats = input.feats.to(weight.dtype)
        return input.replace(F.linear(feats, weight, self.bias))


def quantize_model(model: nn.Module, bits: int = 4, dtype: torch.dtype = torch.float16) -> None:
    """Swap the linear layers nested inside ``model`` for quantized ones, in place.

    Direct children that are ``sp.SparseLinear`` become
    ``QuantizedSparseLinear``; other ``nn.Linear`` children become
    ``QuantizedLinear``. Every other child is descended into recursively.
    ``model`` itself is never replaced, only its descendants.

    Args:
        model: Module whose submodules are rewritten.
        bits: Bit width forwarded to ``from_float``.
        dtype: Dtype forwarded to ``from_float``.
    """
    for child_name, submodule in tuple(model.named_children()):
        # The sparse check runs first, matching the original branch order.
        if isinstance(submodule, sp.SparseLinear):
            replacement_cls = QuantizedSparseLinear
        elif isinstance(submodule, nn.Linear):
            replacement_cls = QuantizedLinear
        else:
            quantize_model(submodule, bits=bits, dtype=dtype)
            continue
        model._modules[child_name] = replacement_cls.from_float(submodule, bits=bits, dtype=dtype)