import torch
from torch import nn
import quiptools_cuda
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda


def get_grid():
    hintr = torch.arange(-8, 8) + 1 / 2
    return hintr.unsqueeze(-1)

_HI4B1C_CACHED = get_grid()
_HI4B1C_NORM_CACHED = torch.diag(_HI4B1C_CACHED @ _HI4B1C_CACHED.T)
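
# HI4B1C (presumably "half-integer, 4-bit, 1-dim codes"): the grid holds the
# 16 half-integer points {-7.5, -6.5, ..., 7.5}, and _HI4B1C_NORM_CACHED
# caches ||g||^2 for each code point so nearest-neighbor rounding below
# reduces to a single argmax.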
class HI4B1C_codebook(nn.Module):

    def __init__(self, inference=False):
        super(HI4B1C_codebook, self).__init__()
        self.opt_scale = 2.97  # empirically chosen input scale; see the search kept below
        self.codesz = 1  # scalar codes: one weight per codebook entry
        self.idx_dtype = torch.int32
        self.packsz = 8  # eight 4-bit indices fit in one int32 word
        self.pack_out = False
        self.version = 0
        self.register_buffer('grid', _HI4B1C_CACHED)
        if not inference:
            self.register_buffer('grid_norm', _HI4B1C_NORM_CACHED)
        # One-off search used to pick opt_scale, kept for reference:
        '''
        self.cuda()
        samples = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(1), torch.eye(1)).rsample([200000]).cuda()
        print(samples.shape)

        def fn_s(s):
            err = (self.quantize(samples * s, False) / s - samples).float().norm()**2
            err = err.cpu() / torch.numel(samples)
            return err.cpu()

        import scipy
        print(scipy.optimize.minimize_scalar(fn_s, bounds=(0.1, 100)))
        exit()
        '''
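
    # Nearest-neighbor rounding: argmin_g ||x - g||^2 = argmax_g (2*x@g - ||g||^2),
    # since ||x||^2 does not depend on g; grid_norm supplies the ||g||^2 term.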
    def round(self, X, grid, grid_norm):
        assert X.shape[-1] == self.codesz
        Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
        return grid[Xqidx], Xqidx
    def quantize(self, X, return_idx=True):
        vals, idx = self.round(X, self.grid, self.grid_norm)
        if not return_idx:
            return vals
        return vals, idx.to(self.idx_dtype)
    def maybe_pack_idxs(self, idxs):
        return (idxs[:, 0::self.packsz] << 4 * 7) + \
               (idxs[:, 2::self.packsz] << 4 * 6) + \
               (idxs[:, 4::self.packsz] << 4 * 5) + \
               (idxs[:, 6::self.packsz] << 4 * 4) + \
               (idxs[:, 1::self.packsz] << 4 * 3) + \
               (idxs[:, 3::self.packsz] << 4 * 2) + \
               (idxs[:, 5::self.packsz] << 4 * 1) + \
               idxs[:, 7::self.packsz]
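
    # Unpacking mirrors maybe_pack_idxs above: each int32 word carries eight
    # 4-bit indices, stored high-to-low in the interleaved order 0, 2, 4, 6,
    # 1, 3, 5, 7. The same packed layout is consumed by the CUDA kernel used
    # in QuantizedHI4B1CLinear.forward below.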
    def by_idxs(self, idxs, packed=False):
        if packed:
            idxs = idxs.repeat_interleave(self.packsz, dim=-1)
            idxs[:, 0::self.packsz] = (idxs[:, 0::self.packsz] >> 28) & 15
            idxs[:, 2::self.packsz] = (idxs[:, 2::self.packsz] >> 24) & 15
            idxs[:, 4::self.packsz] = (idxs[:, 4::self.packsz] >> 20) & 15
            idxs[:, 6::self.packsz] = (idxs[:, 6::self.packsz] >> 16) & 15
            idxs[:, 1::self.packsz] = (idxs[:, 1::self.packsz] >> 12) & 15
            idxs[:, 3::self.packsz] = (idxs[:, 3::self.packsz] >> 8) & 15
            idxs[:, 5::self.packsz] = (idxs[:, 5::self.packsz] >> 4) & 15
            idxs[:, 7::self.packsz] = idxs[:, 7::self.packsz] & 15
        return self.grid[idxs.int()]
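
# Inference wrapper: decompresses the packed 4-bit indices back to fp16
# weights and wraps the matmul in the surrounding pipeline of per-channel
# vectors (SU, SV), Hadamard transforms (had_left/had_right), an optional
# low-rank correction (A, B), and weight scaling (Wscale).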
class QuantizedHI4B1CLinear(nn.Module):

    def __init__(self, device):
        super().__init__()
        self.codebook = HI4B1C_codebook(inference=True).to(torch.float16).to(device)

    def forward(self,
                input,
                Qidxs,
                SU,
                SV,
                Wscale,
                had_left,
                had_right,
                K_left,
                K_right,
                rank=-1,
                A=None,
                B=None,
                rescale_WH=False,
                scaleWH=None,
                packed=False):
        n, m = len(SU), len(SV)
        x = input.view(-1, n).to(torch.float32)
        if rescale_WH:
            x /= scaleWH
        x = x * SU
        x = matmul_hadUt_cuda(x, had_left, K_left)

        if rank > 0:
            # low-rank correction term, computed before the fp16 downcast
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # scale down before the fp16 matmul to keep it in range; undone below
        num_scale = 1024
        x = x / num_scale
        x = x.to(torch.float16)

        if packed:
            W_decompressed = torch.zeros(m, n, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_hi4b1c_packed(Qidxs, self.codebook.grid,
                                                    W_decompressed)
        else:
            W_decompressed = self.codebook.by_idxs(Qidxs, packed=False).reshape(-1, n)

        z = x @ W_decompressed.t()
        x = z.to(torch.float32)
        x = x * (Wscale * num_scale)

        if rank > 0:
            x = x + ABx.to(torch.float32)

        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * SV
        output = x.view(*input.shape[:-1], m)
        return output
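

if __name__ == '__main__':
    # Minimal CPU sanity sketch, not part of the original module; it assumes
    # the imports above resolve (including quiptools_cuda). It quantizes
    # random data, packs the 4-bit indices into int32 words, and checks that
    # the packed and unpacked index paths decode to the same grid values.
    torch.manual_seed(0)
    cb = HI4B1C_codebook()
    vals, idx = cb.quantize(torch.randn(32, 1))
    idxs = idx.view(4, 8)  # packsz = 8 indices per int32 word
    packed = cb.maybe_pack_idxs(idxs)
    assert packed.shape == (4, 1)
    assert torch.equal(cb.by_idxs(packed, packed=True), cb.grid[idxs.int()])
    print('pack/unpack round-trip OK')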