import torch
from .matmul_had import matmul_hadU
import glog
import multiprocessing as mp


def flat_to_sym(V, N):
    # Unpack a flat vector of lower-triangular entries into a full symmetric N x N matrix.
    A = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    idxs = torch.tril_indices(N, N, device=V.device)
    A[idxs.unbind()] = V
    A[idxs[1, :], idxs[0, :]] = V
    return A


def sym_to_flat(A):
    # Pack the lower triangle of a symmetric matrix into a flat vector.
    N = A.shape[-1]
    idxs = torch.tril_indices(N, N, device=A.device)
    return A[idxs.unbind()]
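# A minimal round-trip sketch (illustrative only, not part of the module's API):
# sym_to_flat packs the lower triangle of a symmetric matrix and flat_to_sym
# reconstructs the full symmetric matrix from it.
def _demo_flat_sym_roundtrip(N=4):
    A = torch.randn(N, N, dtype=torch.float64)
    A = (A + A.T) / 2  # symmetrize
    V = sym_to_flat(A)  # N * (N + 1) // 2 lower-triangular entries
    assert torch.allclose(flat_to_sym(V, N), A)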


def register_H_hook(module, device):
    # Accumulate the (unnormalized) second-moment matrix H = sum_x x x^T and the running
    # sum of the inputs seen by `module`, in float64 for numerical stability.
    n = module.in_features
    H = torch.zeros(n, n, dtype=torch.float64, device=device)
    mu = torch.zeros(n, dtype=torch.float64, device=device)
    ct = 0

    def H_hook(module, x):
        nonlocal H, mu, ct, n
        x = x[0].reshape(-1, n).to(torch.float64)
        mu.add_(x.sum(dim=0))
        H.addmm_(x.T, x)
        ct += len(x)

    hook = module.register_forward_pre_hook(H_hook)

    def done():
        # Detach the hook and return the accumulated statistics plus the sample count.
        nonlocal H, mu, ct, hook
        hook.remove()
        return H.cpu(), mu.cpu(), ct

    return done
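# A minimal usage sketch, with hypothetical model/layer names: register the hook on a
# linear layer, run calibration batches through the model, then call `done()` to detach
# the hook and collect the accumulated statistics.
def _demo_register_H_hook(model, calib_batches, device):
    done = register_H_hook(model.layers[0].mlp.up_proj, device)  # hypothetical module path
    with torch.no_grad():
        for input_ids in calib_batches:
            model(input_ids.to(device))
    H, mu, ct = done()
    return H / ct, mu / ct  # empirical second moment and mean of the layer's inputs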


def block_LDL(H, b):
    # Block LDL^T factorization of a symmetric positive definite H: returns a block unit
    # lower triangular L (n x n) and the b x b diagonal blocks D (m, b, b) such that
    # H = L @ block_diag(D) @ L.T, where m = n // b.
    n = H.shape[0]
    assert (n % b == 0)
    m = n // b
    L = torch.linalg.cholesky(H)
    DL = torch.diagonal(L.reshape(m, b, m, b), dim1=0, dim2=2).permute(2, 0, 1)
    D = DL @ DL.permute(0, 2, 1)

    L = L.view(n, m, b)
    for i in range(m):
        # Solve X @ DL[i] = L[:, i, :], i.e. right-multiply by DL[i]^{-1}.
        L[:, i, :] = torch.linalg.solve(DL[i, :, :], L[:, i, :], left=False)
    L = L.reshape(n, n)
    return (L, D)
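# A small self-check sketch of the factorization: the factors returned by block_LDL
# satisfy H ≈ L @ block_diag(D) @ L.T.
def _demo_block_LDL(n=8, b=4):
    X = torch.randn(n, 2 * n, dtype=torch.float64)
    H = X @ X.T + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite
    L, D = block_LDL(H, b)
    assert torch.allclose(L @ torch.block_diag(*D) @ L.T, H)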


def wrap_tokenizer(tokenizer, x, ctx_size):
    return tokenizer(x, return_tensors='pt', truncation=True, padding=True, max_length=ctx_size)


def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    # Sample `size` sequences of exactly `ctx_size` tokens from the dataset's 'text' field,
    # retokenizing random batches until enough full-length sequences have been collected.
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0
    if nproc > 1:
        p = mp.Pool(nproc)
        while saved < size:
            seqs = [(tokenizer, dataset[torch.randint(len(dataset), (size,))]['text'], ctx_size) for _ in range(nproc)]
            tokens = p.starmap(wrap_tokenizer, seqs)
            for i in range(len(tokens)):
                lens = tokens[i].attention_mask.sum(dim=-1)
                good = torch.where(lens == ctx_size)[0]
                if len(good) > 0:
                    if saved + len(good) > size:
                        good = good[:size - saved]
                    devset[saved:saved + len(good)] = tokens[i].input_ids[good]
                    saved += len(good)
            print(saved)
    else:
        while saved < size:
            tokens = tokenizer(dataset[torch.randint(len(dataset), (size,))]['text'],
                               return_tensors='pt',
                               truncation=True,
                               padding=True,
                               max_length=ctx_size)
            lens = tokens.attention_mask.sum(dim=-1)
            good = torch.where(lens == ctx_size)[0]
            if len(good) > 0:
                if saved + len(good) > size:
                    good = good[:size - saved]
                devset[saved:saved + len(good)] = tokens.input_ids[good]
                saved += len(good)
    return devset
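# A minimal usage sketch (assumes the `datasets` and `transformers` packages; the dataset
# and tokenizer names below are illustrative, and the tokenizer needs a pad token for
# padding to work).
def _demo_sample_devset():
    import datasets
    import transformers
    tokenizer = transformers.AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
    tokenizer.pad_token = tokenizer.eos_token
    dataset = datasets.load_dataset('togethercomputer/RedPajama-Data-1T-Sample', split='train')
    # Returns a (16, 2048) int64 tensor of token ids from sequences that fill the context.
    return sample_devset(dataset, tokenizer, size=16, ctx_size=2048, nproc=1)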


def load_quip(save_name, cb, args, device):
    # Reconstruct the dequantized weight matrix of a cached compressed layer from the saved
    # codebook indices, scales, and incoherence-processing transforms.
    glog.info(f"loading cached compressed layer from path \"{save_name}\"")
    dict_loaded = torch.load(save_name, map_location=torch.device('cuda', device))
    SU = dict_loaded['SU'].to(device)
    SV = dict_loaded['SV'].to(device)
    Wscale = dict_loaded['Wscale'].to(device)
    Qidxs = dict_loaded['Qidxs'].to(device)
    n, m = len(SU), len(SV)
    hatWr = cb.to(device).by_idxs(Qidxs, packed=(cb.packsz != 1)).view(m, n)
    hatWr = hatWr * Wscale
    del Wscale
    if args.lora_rank > 0:
        A = dict_loaded['A'].to(device)
        B = dict_loaded['B'].to(device)
        hatWr = hatWr + A @ B
        del A, B
    if args.incoh_mode == "had":
        hatW = (matmul_hadU((matmul_hadU(hatWr) * SU).T) * SV).T
    elif args.incoh_mode == "kron":
        hatW = SV.T @ hatWr @ SU
    else:
        raise NotImplementedError
    del SU, SV
    if args.rescale_WH:
        hatW = hatW / dict_loaded['scaleWH'][None, :].to(device)
    return hatW


def dtype_from_str(s):
    # Map a string such as 'torch.int32' back to the corresponding torch dtype.
    dtype_map = {
        'torch.int64': torch.int64,
        'torch.int32': torch.int32,
        'torch.int16': torch.int16,
        'torch.uint8': torch.uint8,
    }
    return dtype_map[s]