File size: 4,470 Bytes
c1a41d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
from torch import nn
import quiptools_cuda

from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda


def get_grid():
    """Return the half-integer 4-bit codebook grid.

    The 16 entries are -7.5, -6.5, ..., 6.5, 7.5 (integers shifted by 1/2),
    shaped (16, 1) so that each codeword is a length-1 vector (codesz == 1).
    """
    hintr = torch.arange(-8, 8) + 1 / 2
    return hintr.unsqueeze(-1)


# Built once at import time and shared by every codebook instance.
_HI4B1C_CACHED = get_grid()
# Squared norm ||g||^2 of each codeword, used by the nearest-codeword argmax.
_HI4B1C_NORM_CACHED = torch.diag(_HI4B1C_CACHED @ _HI4B1C_CACHED.T)


class HI4B1C_codebook(nn.Module):
    """Half-integer 4-bit 1-column ("HI4B1C") scalar quantization codebook.

    Each value is rounded to the nearest of 16 half-integer grid points, and
    eight 4-bit indices are packed into one int32 in an interleaved nibble
    order (see ``_PACK_PERM``).
    """

    # Column offsets (mod packsz) in most-significant-nibble-first order:
    # nibble k of a packed int32 (k == 0 is bits 28..31) holds the index at
    # column offset _PACK_PERM[k].  NOTE(review): this interleaving must stay
    # in sync with quiptools_cuda.decompress_hi4b1c_packed — do not reorder.
    _PACK_PERM = (0, 2, 4, 6, 1, 3, 5, 7)

    def __init__(self, inference=False):
        super().__init__()
        # Scale minimizing E||Q(s*x)/s - x||^2 for x ~ N(0, 1), found offline
        # via scipy.optimize.minimize_scalar over 200k Gaussian samples.
        self.opt_scale = 2.97
        self.codesz = 1  # codewords are scalars (length-1 vectors)
        self.idx_dtype = torch.int32
        self.packsz = 8  # eight 4-bit indices per int32
        self.pack_out = False
        self.version = 0

        self.register_buffer('grid', _HI4B1C_CACHED)
        if not inference:
            # grid_norm is only needed to quantize, so inference-mode
            # instances skip the buffer.
            self.register_buffer('grid_norm', _HI4B1C_NORM_CACHED)

    def round(self, X, grid, grid_norm):
        """Nearest-codeword lookup; returns (quantized values, grid indices).

        Uses argmin_g ||x - g||^2 == argmax_g (2 x.g - ||g||^2), which avoids
        materializing pairwise distances.
        """
        assert X.shape[-1] == self.codesz
        Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
        return grid[Xqidx], Xqidx

    def quantize(self, X, return_idx=True):
        """Quantize X onto the grid; optionally also return int32 indices."""
        vals, idx = self.round(X, self.grid, self.grid_norm)
        if not return_idx:
            return vals
        return vals, idx.to(self.idx_dtype)

    def maybe_pack_idxs(self, idxs):
        """Pack each run of eight 4-bit indices into one int32 (_PACK_PERM order)."""
        packed = torch.zeros_like(idxs[:, 0::self.packsz])
        for nibble, col in enumerate(self._PACK_PERM):
            packed = packed + (idxs[:, col::self.packsz] << 4 * (7 - nibble))
        return packed

    def by_idxs(self, idxs, packed=False):
        """Look up codewords for (optionally packed) indices."""
        if packed:
            # Expand each int32 into eight columns, then keep only the nibble
            # belonging to each column — the exact inverse of maybe_pack_idxs.
            idxs = idxs.repeat_interleave(self.packsz, dim=-1)
            for nibble, col in enumerate(self._PACK_PERM):
                idxs[:, col::self.packsz] = \
                    (idxs[:, col::self.packsz] >> 4 * (7 - nibble)) & 15

        return self.grid[idxs.int()]


class QuantizedHI4B1CLinear(nn.Module):
    """Linear layer whose weight is stored as HI4B1C codebook indices.

    Dequantizes the weight on the fly and computes, roughly,
    ``y = SV * H_R(Wscale * W_hat @ H_L^T(SU * x))`` where ``W_hat`` is the
    decompressed weight and ``H_L`` / ``H_R`` are the transforms implemented by
    ``matmul_hadUt_cuda`` / ``matmul_hadU_cuda`` (presumably Hadamard-based,
    per the ``matmul_had`` module name — confirm against lib.utils).
    """

    def __init__(self, device):
        super().__init__()
        # inference=True skips the grid_norm buffer (only needed to quantize).
        self.codebook = HI4B1C_codebook(inference=True).to(torch.float16).to(device)

    def forward(self,
                input,
                Qidxs,
                SU,
                SV,
                Wscale,
                had_left,
                had_right,
                K_left,
                K_right,
                rank=-1,
                A=None,
                B=None,
                rescale_WH=False,
                scaleWH=None,
                packed=False):
        """Apply the quantized linear map to ``input``.

        Args:
            input: activations; last dim must equal ``len(SU)``.
            Qidxs: weight codebook indices; packed int32 nibbles when
                ``packed`` is True, otherwise one index per weight.
            SU, SV: per-input / per-output channel scale (or sign) vectors;
                their lengths define n (in features) and m (out features).
            Wscale: global weight scale undone after the matmul.
            had_left, had_right, K_left, K_right: transform operands passed
                through to matmul_hadUt_cuda / matmul_hadU_cuda.
            rank, A, B: if rank > 0, a low-rank correction A @ B is added
                (A and B must then be non-None).
            rescale_WH, scaleWH: if set, input is divided elementwise by
                scaleWH first.
            packed:选择 is not relevant here — selects the CUDA unpack kernel
                vs. the Python codebook lookup.

        Returns:
            Tensor of shape ``input.shape[:-1] + (m,)``.
        """
        # n = in features (len SU), m = out features (len SV).
        n, m = len(SU), len(SV)

        # Flatten to (batch, n) in fp32; .to() copies, so the in-place
        # division below does not touch the caller's tensor.
        x = input.view(-1, n).to(torch.float32)
        if rescale_WH:
            x /= scaleWH
        x = x * SU
        x = matmul_hadUt_cuda(x, had_left, K_left)

        if rank > 0:
            # Low-rank correction ABx = x @ B^T @ A^T, computed on the
            # pre-scaled fp32 activations and added back after the
            # quantized matmul.
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # Scale down in fp32 *before* the fp16 cast (presumably to avoid
        # fp16 overflow in the matmul); undone after the matmul below.
        num_scale = 1024
        x = x / num_scale
        x = x.to(torch.float16)

        if packed:
            # CUDA kernel fills W_decompressed (m, n) fp16 in place from the
            # packed int32 indices and the codebook grid.
            W_decompressed = torch.zeros(m, n, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_hi4b1c_packed(Qidxs, self.codebook.grid, W_decompressed)
        else:
            W_decompressed = self.codebook.by_idxs(Qidxs, packed=False).reshape(-1, n)

        z = x @ W_decompressed.t()

        x = z.to(torch.float32)
        # Undo num_scale and apply the global weight scale in one multiply.
        x = x * (Wscale * num_scale)

        if rank > 0:
            x = x + ABx.to(torch.float32)

        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * SV

        # Restore the caller's leading batch dims with m output features.
        output = x.view(*input.shape[:-1], m)

        return output