import torch
from .matmul_had import matmul_hadU
import glog
import multiprocessing as mp


def flat_to_sym(V, N):
    # Expand a flattened vector of lower-triangular entries V into a full
    # symmetric N x N matrix.
    A = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    idxs = torch.tril_indices(N, N, device=V.device)
    A[idxs.unbind()] = V
    A[idxs[1, :], idxs[0, :]] = V
    return A


def sym_to_flat(A):
    # Flatten a symmetric matrix A into a vector of its lower-triangular entries.
    N = A.shape[-1]
    idxs = torch.tril_indices(N, N, device=A.device)
    return A[idxs.unbind()]
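
# Illustrative sketch (not part of the original module): flat_to_sym and sym_to_flat
# invert each other on symmetric matrices, e.g.
#
#     S = torch.randn(4, 4)
#     S = S + S.T                                   # make it symmetric
#     assert torch.equal(flat_to_sym(sym_to_flat(S), 4), S)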


def register_H_hook(module, device):
    # Accumulate the unnormalized second-moment matrix H = sum_x x x^T, the running
    # sum mu of the inputs, and the sample count for a linear module's inputs.
    n = module.in_features
    H = torch.zeros(n, n, dtype=torch.float64, device=device)
    mu = torch.zeros(n, dtype=torch.float64, device=device)
    ct = 0

    def H_hook(module, x):
        nonlocal H, mu, ct, n
        x = x[0].reshape(-1, n).to(torch.float64)
        mu.add_(x.sum(dim=0))
        H.addmm_(x.T, x)
        ct += len(x)

    hook = module.register_forward_pre_hook(H_hook)

    def done():
        # Remove the hook and return the accumulated statistics.
        nonlocal H, mu, ct, hook
        hook.remove()
        return H.cpu(), mu.cpu(), ct

    return done
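
# Illustrative sketch (not part of the original module): typical usage is to register
# the hook on a linear layer, run calibration data through the model, then collect
# the statistics (module path and batch name below are hypothetical), e.g.
#
#     done = register_H_hook(model.layers[0].mlp.up_proj, 'cuda:0')
#     with torch.no_grad():
#         model(calibration_batch)
#     H, mu, ct = done()      # H / ct is the empirical second moment of the inputs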


def block_LDL(H, b):
    # Block LDL^T factorization of H with block size b: returns (L, D) with
    # H = L @ block_diag(D) @ L.T, where L is block unit lower triangular and
    # D is the block diagonal stored as m = n // b blocks of shape (b, b).
    n = H.shape[0]
    assert (n % b == 0)
    m = n // b
    L = torch.linalg.cholesky(H)
    # Diagonal b x b blocks of the Cholesky factor.
    DL = torch.diagonal(L.reshape(m, b, m, b), dim1=0, dim2=2).permute(2, 0, 1)
    D = DL @ DL.permute(0, 2, 1)
    # Normalize each block column of L by its diagonal block; the right-sided solve
    # computes L[:, i, :] @ DL[i]^{-1} without forming the inverse explicitly.
    L = L.view(n, m, b)
    for i in range(m):
        L[:, i, :] = torch.linalg.solve(DL[i, :, :], L[:, i, :], left=False)
    L = L.reshape(n, n)
    return (L, D)
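
# Illustrative sketch (not part of the original module): the factors can be checked by
# reassembling the block-diagonal D, e.g.
#
#     H = torch.randn(8, 16, dtype=torch.float64)
#     H = H @ H.T + 1e-2 * torch.eye(8, dtype=torch.float64)   # symmetric positive definite
#     L, D = block_LDL(H, 4)
#     assert torch.allclose(L @ torch.block_diag(*D) @ L.T, H)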


def wrap_tokenizer(tokenizer, x, ctx_size):
    return tokenizer(x, return_tensors='pt', truncation=True, padding=True, max_length=ctx_size)


def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    # Draw `size` random sequences from `dataset` that tokenize to exactly
    # `ctx_size` tokens; shorter sequences are discarded and resampled.
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0
    if nproc > 1:
        p = mp.Pool(nproc)
        while saved < size:
            seqs = [(tokenizer, dataset[torch.randint(len(dataset), (size,))]['text'], ctx_size)
                    for _ in range(nproc)]
            tokens = p.starmap(wrap_tokenizer, seqs)
            for i in range(len(tokens)):
                lens = tokens[i].attention_mask.sum(dim=-1)
                good = torch.where(lens == ctx_size)[0]
                if len(good) > 0:
                    if saved + len(good) > size:
                        good = good[:size - saved]
                    devset[saved:saved + len(good)] = tokens[i].input_ids[good]
                    saved += len(good)
                    print(saved)
    else:
        while saved < size:
            tokens = tokenizer(dataset[torch.randint(len(dataset), (size,))]['text'],
                               return_tensors='pt',
                               truncation=True,
                               padding=True,
                               max_length=ctx_size)
            lens = tokens.attention_mask.sum(dim=-1)
            good = torch.where(lens == ctx_size)[0]
            if len(good) > 0:
                if saved + len(good) > size:
                    good = good[:size - saved]
                devset[saved:saved + len(good)] = tokens.input_ids[good]
                saved += len(good)
    return devset
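
# Illustrative sketch (not part of the original module): with a Hugging Face dataset and
# tokenizer (the dataset and model names below are hypothetical), a calibration set can
# be drawn as, e.g.
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#     dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
#     tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
#     tokenizer.pad_token = tokenizer.eos_token     # padding=True needs a pad token
#     devset = sample_devset(dataset, tokenizer, size=64, ctx_size=2048)   # (64, 2048) int64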


def load_quip(save_name, cb, args, device):
    # Reconstruct a dequantized weight matrix from a cached compressed layer:
    # look up codebook entries, apply the weight scale, add an optional low-rank
    # correction, and undo the incoherence processing.
    glog.info(f"loading cached compressed layer from path \"{save_name}\"")
    dict_loaded = torch.load(save_name, map_location=torch.device('cuda', device))
    SU = dict_loaded['SU'].to(device)
    SV = dict_loaded['SV'].to(device)
    Wscale = dict_loaded['Wscale'].to(device)
    Qidxs = dict_loaded['Qidxs'].to(device)
    n, m = len(SU), len(SV)
    hatWr = cb.to(device).by_idxs(Qidxs, packed=(cb.packsz != 1)).view(m, n)
    hatWr = hatWr * Wscale
    del Wscale
    if args.lora_rank > 0:
        A = dict_loaded['A'].to(device)
        B = dict_loaded['B'].to(device)
        hatWr = hatWr + A @ B
        del A, B
    if args.incoh_mode == "had":
        hatW = (matmul_hadU((matmul_hadU(hatWr) * SU).T) * SV).T
    elif args.incoh_mode == "kron":
        hatW = SV.T @ hatWr @ SU
    else:
        raise NotImplementedError
    del SU, SV
    if args.rescale_WH:
        hatW = hatW / dict_loaded['scaleWH'][None, :].to(device)
    return hatW
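
# Illustrative sketch (not part of the original module): `cb` is a codebook object exposing
# `.by_idxs()` and `.packsz`, and `args` carries the quantization flags; a hypothetical
# call looks like, e.g.
#
#     hatW = load_quip('layer0_q_proj.pt', codebook, args, device=0)
#     layer.weight.data.copy_(hatW.to(layer.weight.dtype))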


def dtype_from_str(dtype_str):
    # Map a string such as 'torch.int32' back to the corresponding torch dtype.
    dtype_map = {
        'torch.int64': torch.int64,
        'torch.int32': torch.int32,
        'torch.int16': torch.int16,
        'torch.uint8': torch.uint8,
    }
    return dtype_map[dtype_str]
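
# Illustrative sketch (not part of the original module): used to restore dtypes that
# were serialized as strings, e.g.
#
#     dtype_from_str('torch.int32')                     # -> torch.int32
#     torch.zeros(4, dtype=dtype_from_str('torch.uint8'))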