| import torch | |
| from torch import nn | |
| import quiptools_cuda | |
| from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda | |
def get_grid():
    """Build the scalar codebook grid as a column tensor.

    Returns:
        A tensor of shape (16, 1) holding the half-integers
        -7.5, -6.5, ..., 6.5, 7.5 in ascending order.
    """
    half_integers = torch.arange(-8, 8) + 0.5
    # Add a trailing axis so every grid point is a length-1 code vector.
    return half_integers[:, None]
# Grid points, computed once at import time and shared by every codebook.
_HI4B1C_CACHED = get_grid()
# Squared norm of each grid point: the diagonal of the grid's Gram matrix.
_HI4B1C_NORM_CACHED = torch.mm(_HI4B1C_CACHED, _HI4B1C_CACHED.t()).diag()
class HI4B1C_codebook(nn.Module):
    """Codebook with one scalar per code and 4-bit indices.

    The ``grid`` buffer comes from the module-level ``_HI4B1C_CACHED`` tensor.
    Indices are masked to four bits, and eight of them are packed into one
    integer word in an interleaved lane order (see ``_NIBBLE_LAYOUT``).
    """

    # (lane inside a group of ``packsz`` indices, bit offset in the packed
    # word). Even lanes fill the upper four nibbles, odd lanes the lower four.
    _NIBBLE_LAYOUT = (
        (0, 28),
        (2, 24),
        (4, 20),
        (6, 16),
        (1, 12),
        (3, 8),
        (5, 4),
        (7, 0),
    )

    def __init__(self, inference=False):
        """Set the codebook constants and register the grid buffers.

        Args:
            inference: when true, ``grid_norm`` is not registered. ``quantize``
                reads ``grid_norm``, so it is unavailable in that mode.
        """
        super().__init__()
        # A disabled calibration snippet used to sit in this constructor as an
        # inert string literal. It ran scipy.optimize.minimize_scalar over
        # bounds (0.1, 100) on the mean squared error of
        # quantize(samples * s) / s against 200000 standard-normal samples.
        # NOTE(review): presumably that search is where 2.97 came from.
        self.opt_scale = 2.97
        self.codesz = 1
        self.idx_dtype = torch.int32
        self.packsz = 8
        self.pack_out = False
        self.version = 0

        self.register_buffer('grid', _HI4B1C_CACHED)
        if not inference:
            self.register_buffer('grid_norm', _HI4B1C_NORM_CACHED)

    def round(self, X, grid, grid_norm):
        """Pick, for every row of ``X``, the grid row with the highest score.

        The score is ``2 * <x, g> - grid_norm[g]``; when ``grid_norm`` holds the
        squared norms of the grid rows, the argmax is the nearest grid row.

        Args:
            X: tensor whose last dimension equals ``self.codesz``.
            grid: candidate code vectors, one per row.
            grid_norm: per-row offset subtracted from the scores.

        Returns:
            ``(values, indices)``: the selected grid rows and their indices.
        """
        assert X.shape[-1] == self.codesz
        scores = (2 * X) @ grid.T - grid_norm
        nearest = scores.argmax(-1)
        return grid[nearest], nearest

    def quantize(self, X, return_idx=True):
        """Round ``X`` onto this codebook's grid.

        Args:
            X: tensor whose last dimension equals ``self.codesz``.
            return_idx: when false, only the rounded values are returned.

        Returns:
            The rounded values, or ``(values, indices)`` with the indices cast
            to ``self.idx_dtype`` when ``return_idx`` is true.
        """
        vals, idx = self.round(X, self.grid, self.grid_norm)
        if return_idx:
            return vals, idx.to(self.idx_dtype)
        return vals

    def maybe_pack_idxs(self, idxs):
        """Pack every ``self.packsz`` consecutive 4-bit indices into one word.

        Despite the name, packing here is unconditional.

        Args:
            idxs: 2-D integer tensor of indices. The layout assumes groups of
                eight along the second dimension.

        Returns:
            A tensor with one packed word per group, lanes placed according to
            ``_NIBBLE_LAYOUT``.
        """
        stride = self.packsz
        lanes = [
            idxs[:, lane::stride] << shift
            for lane, shift in self._NIBBLE_LAYOUT
        ]
        packed = lanes[0]
        for lane_bits in lanes[1:]:
            packed = packed + lane_bits
        return packed

    def by_idxs(self, idxs, packed=False):
        """Look up grid rows by index, unpacking packed words first if asked.

        Args:
            idxs: integer index tensor, or packed words when ``packed`` is true.
            packed: when true, each word is expanded into ``self.packsz``
                4-bit indices, reversing ``maybe_pack_idxs``.

        Returns:
            ``self.grid`` indexed by the (unpacked) indices.
        """
        if packed:
            stride = self.packsz
            # Give every lane its own copy of the word, then reduce each copy
            # to the nibble that belongs to that lane.
            expanded = idxs.repeat_interleave(stride, dim=-1)
            for lane, shift in self._NIBBLE_LAYOUT:
                # The mask also discards sign bits pulled in by the shift.
                expanded[:, lane::stride] = (
                    expanded[:, lane::stride] >> shift) & 15
            idxs = expanded
        return self.grid[idxs.int()]
class QuantizedHI4B1CLinear(nn.Module):
    """Linear layer whose weight is stored as HI4B1C codebook indices."""

    def __init__(self, device):
        """Create the inference-mode codebook in float16 on ``device``."""
        super().__init__()
        self.codebook = HI4B1C_codebook(inference=True).to(torch.float16).to(device)

    def forward(self,
                input,
                Qidxs,
                SU,
                SV,
                Wscale,
                had_left,
                had_right,
                K_left,
                K_right,
                rank=-1,
                A=None,
                B=None,
                rescale_WH=False,
                scaleWH=None,
                packed=False):
        """Apply the quantized linear map to ``input``.

        Steps, in order: optional division by ``scaleWH``, scaling by ``SU``,
        ``matmul_hadUt_cuda``, a float16 product with the decompressed weight,
        rescaling by ``Wscale``, the optional low-rank correction,
        ``matmul_hadU_cuda`` and scaling by ``SV``.

        Args:
            input: activations whose last dimension is ``n = len(SU)``. It is
                never modified in place.
            Qidxs: weight indices. Handed to
                ``quiptools_cuda.decompress_hi4b1c_packed`` when ``packed`` is
                true; otherwise looked up in the codebook and reshaped to
                ``(-1, n)``.
            SU: multiplied into the activations before the weight product;
                its length defines ``n``.
            SV: multiplied into the result after the weight product; its
                length defines the output feature count ``m``.
            Wscale: factor multiplied into the weight product.
            had_left, K_left: forwarded unchanged to ``matmul_hadUt_cuda``.
            had_right, K_right: forwarded unchanged to ``matmul_hadU_cuda``.
            rank: when greater than zero, ``(x @ B.T) @ A.T`` is added to the
                weight product. That term is computed from the activations as
                they stand after the left transform.
            A, B: low-rank factors; only read when ``rank > 0``.
            rescale_WH: when true, the activations are divided by ``scaleWH``
                first.
            scaleWH: divisor used when ``rescale_WH`` is true.
            packed: selects the CUDA decompression kernel over the codebook
                lookup.

        Returns:
            A tensor shaped ``input.shape[:-1] + (m,)``.
        """
        n, m = len(SU), len(SV)
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = input.reshape(-1, n).to(torch.float32)
        if rescale_WH:
            # Must be out-of-place: for a float32 input, ``x`` still shares
            # storage with ``input``, so ``x /= scaleWH`` would overwrite the
            # caller's tensor. The cast keeps the result float32, matching
            # what the in-place form produced.
            x = (x / scaleWH).to(torch.float32)
        x = x * SU
        x = matmul_hadUt_cuda(x, had_left, K_left)

        if rank > 0:
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # Shrink the activations before the float16 cast and undo it after the
        # product. NOTE(review): presumably this keeps the half-precision
        # matmul within range; confirm.
        num_scale = 1024
        x = x / num_scale
        x = x.to(torch.float16)

        if packed:
            W_decompressed = torch.zeros(m, n, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_hi4b1c_packed(Qidxs, self.codebook.grid, W_decompressed)
        else:
            W_decompressed = self.codebook.by_idxs(Qidxs, packed=False).reshape(-1, n)

        z = x @ W_decompressed.t()
        x = z.to(torch.float32)
        x = x * (Wscale * num_scale)

        if rank > 0:
            x = x + ABx.to(torch.float32)

        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * SV

        output = x.view(*input.shape[:-1], m)
        return output