|
|
import torch |
|
|
from .matmul_had import matmul_hadU |
|
|
import glog |
|
|
import multiprocessing as mp |
|
|
|
|
|
def flat_to_sym(V, N):
    """Expand a packed lower triangle into a full symmetric ``N x N`` matrix.

    Args:
        V: values of the lower triangle (diagonal included), ordered the way
            ``torch.tril_indices(N, N)`` enumerates positions.
        N: side length of the square matrix to build.

    Returns:
        An ``(N, N)`` tensor with the dtype and device of ``V`` in which both
        ``[r, c]`` and ``[c, r]`` hold the value packed for position
        ``(r, c)``.
    """
    rows, cols = torch.tril_indices(N, N, device=V.device)
    sym = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    # Fill the lower triangle, then mirror the same values across the diagonal.
    sym[rows, cols] = V
    sym[cols, rows] = V
    return sym
|
|
|
|
|
|
|
|
def sym_to_flat(A):
    """Pack the lower triangle (diagonal included) of ``A`` into a 1-D tensor.

    The side length is taken from the last dimension of ``A`` and the
    triangle is read by indexing the first two dimensions, in the order
    produced by ``torch.tril_indices``.

    Args:
        A: matrix to pack (only its lower triangle is read).

    Returns:
        The lower-triangular entries of ``A``.
    """
    side = A.shape[-1]
    rows, cols = torch.tril_indices(side, side, device=A.device)
    return A[rows, cols]
|
|
|
|
|
|
|
|
def register_H_hook(module, device):
    """Attach a forward pre-hook that accumulates input statistics.

    On every forward call the first positional input of ``module`` is
    flattened to ``(-1, module.in_features)``, cast to float64, and folded
    into three running quantities: the sum of outer products ``x.T @ x``,
    the per-feature sum, and the number of rows seen.

    Args:
        module: module exposing ``in_features`` and
            ``register_forward_pre_hook``.
        device: device on which the accumulators are allocated.

    Returns:
        A zero-argument callable that removes the hook and returns
        ``(H, mu, ct)``: the accumulated outer-product sum and feature sum
        (both moved to CPU) and the row count.
    """
    dim = module.in_features
    gram = torch.zeros(dim, dim, dtype=torch.float64, device=device)
    feature_sum = torch.zeros(dim, dtype=torch.float64, device=device)
    # Boxed in a list so the hook can update the count in place.
    rows_seen = [0]

    def H_hook(module, x):
        flat = x[0].reshape(-1, dim).to(torch.float64)
        feature_sum.add_(flat.sum(dim=0))
        gram.addmm_(flat.T, flat)
        rows_seen[0] += len(flat)

    handle = module.register_forward_pre_hook(H_hook)

    def done():
        handle.remove()
        return gram.cpu(), feature_sum.cpu(), rows_seen[0]

    return done
|
|
|
|
|
|
|
|
def block_LDL(H, b):
    """Compute a block LDL factorization of ``H`` with ``b x b`` blocks.

    The Cholesky factor ``C`` of ``H`` is split as ``C = L @ blockdiag(DL)``,
    where ``DL`` holds the diagonal ``b x b`` blocks of ``C``. With
    ``D[i] = DL[i] @ DL[i].T`` this gives ``H = L @ blockdiag(D) @ L.T``.

    Args:
        H: square matrix accepted by ``torch.linalg.cholesky``; its side
            length must be a multiple of ``b``.
        b: block size.

    Returns:
        ``(L, D)`` where ``L`` is ``(n, n)`` and ``D`` is ``(n // b, b, b)``.
    """
    n = H.shape[0]
    assert (n % b == 0)
    num_blocks = n // b
    chol = torch.linalg.cholesky(H)
    # (num_blocks, b, b) stack of the diagonal blocks, viewed from ``chol``.
    diag_blocks = torch.diagonal(
        chol.reshape(num_blocks, b, num_blocks, b), dim1=0, dim2=2
    ).permute(2, 0, 1)
    # Computed before the loop below overwrites the diagonal blocks in place.
    D = diag_blocks @ diag_blocks.permute(0, 2, 1)

    blocked = chol.view(n, num_blocks, b)
    for k in range(num_blocks):
        # Right-divide block column k by its diagonal block: X @ DL[k] = C[:, k].
        blocked[:, k, :] = torch.linalg.solve(
            diag_blocks[k, :, :], blocked[:, k, :], left=False
        )
    return (blocked.reshape(n, n), D)
|
|
|
|
|
def wrap_tokenizer(tokenizer, x, ctx_size):
    """Call ``tokenizer`` on ``x`` with fixed truncation/padding settings.

    Defined at module level so it can be handed to
    ``multiprocessing.Pool.starmap``.

    Args:
        tokenizer: callable accepting the keyword arguments used below.
        x: input forwarded as the tokenizer's first positional argument.
        ctx_size: value passed as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns.
    """
    options = {
        'return_tensors': 'pt',
        'truncation': True,
        'padding': True,
        'max_length': ctx_size,
    }
    return tokenizer(x, **options)
|
|
|
|
|
def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    """Sample ``size`` token sequences that are exactly ``ctx_size`` long.

    Random batches of ``size`` rows are drawn from ``dataset`` and tokenized
    with truncation and padding to at most ``ctx_size`` tokens. Only rows
    whose attention mask sums to ``ctx_size`` (i.e. rows with no padding)
    are kept. Drawing repeats until ``size`` rows have been collected.

    NOTE: the loop only makes progress when a drawn text tokenizes to at
    least ``ctx_size`` tokens; if no text in ``dataset`` is that long this
    function never returns.

    Args:
        dataset: object supporting ``len()`` and indexing by an int64 index
            tensor, returning a mapping with a ``'text'`` entry.
        tokenizer: callable returning an object with ``attention_mask`` and
            ``input_ids`` tensors. It must be picklable when ``nproc > 1``.
        size: number of sequences to collect (also the draw batch size).
        ctx_size: required sequence length in tokens.
        nproc: when greater than 1, tokenize ``nproc`` batches per round in
            a ``multiprocessing`` pool and print the running count.

    Returns:
        An int64 tensor of shape ``(size, ctx_size)`` holding the token ids.
    """
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0

    def draw_texts():
        # One candidate batch: ``size`` uniformly random dataset rows.
        return dataset[torch.randint(len(dataset), (size,))]['text']

    def keep_full_rows(batch):
        # Copy the unpadded rows of ``batch`` into ``devset``; return how
        # many rows were stored (never more than the remaining capacity).
        nonlocal saved
        lens = batch.attention_mask.sum(dim=-1)
        good = torch.where(lens == ctx_size)[0][:size - saved]
        if len(good) == 0:
            return 0
        devset[saved:saved + len(good)] = batch.input_ids[good]
        saved += len(good)
        return len(good)

    if nproc > 1:
        # The context manager shuts the worker processes down on exit; the
        # pool was previously left open for the life of the interpreter.
        with mp.Pool(nproc) as pool:
            while saved < size:
                jobs = [(tokenizer, draw_texts(), ctx_size) for _ in range(nproc)]
                for batch in pool.starmap(wrap_tokenizer, jobs):
                    if keep_full_rows(batch):
                        print(saved)
                    if saved >= size:
                        break
    else:
        while saved < size:
            batch = tokenizer(draw_texts(),
                              return_tensors='pt',
                              truncation=True,
                              padding=True,
                              max_length=ctx_size)
            keep_full_rows(batch)
    return devset
|
|
|
|
|
|
|
|
def load_quip(save_name, cb, args, device):
    """Rebuild a weight matrix from a cached compressed-layer file.

    The file at ``save_name`` is read with ``torch.load`` and must contain
    the entries ``'SU'``, ``'SV'``, ``'Wscale'`` and ``'Qidxs'``; ``'A'`` and
    ``'B'`` are read when ``args.lora_rank > 0`` and ``'scaleWH'`` when
    ``args.rescale_WH`` is truthy.

    Args:
        save_name: path of the cached file.
        cb: codebook-like object exposing ``to``, ``by_idxs`` and ``packsz``.
        args: options object with ``lora_rank``, ``incoh_mode`` (``"had"``
            or ``"kron"``) and ``rescale_WH`` attributes.
        device: target device; also passed as the index to
            ``torch.device('cuda', device)`` for ``map_location``.

    Returns:
        The reconstructed weight tensor ``hatW``.

    Raises:
        NotImplementedError: if ``args.incoh_mode`` is neither ``"had"`` nor
            ``"kron"``.
    """
    glog.info(f"loading cached compressed layer from path \"{save_name}\"")
    saved = torch.load(save_name, map_location=torch.device('cuda', device))

    SU = saved['SU'].to(device)
    SV = saved['SV'].to(device)
    scale = saved['Wscale'].to(device)
    indices = saved['Qidxs'].to(device)
    n = len(SU)
    m = len(SV)

    # Decode the stored indices through the codebook into an (m, n) matrix.
    codebook = cb.to(device)
    decoded = codebook.by_idxs(indices, packed=(cb.packsz != 1))
    hatWr = decoded.view(m, n) * scale
    del scale

    if args.lora_rank > 0:
        # Add the low-rank correction stored next to the quantized weights.
        A = saved['A'].to(device)
        B = saved['B'].to(device)
        hatWr = hatWr + A @ B
        del A, B

    mode = args.incoh_mode
    if mode == "kron":
        hatW = SV.T @ hatWr @ SU
    elif mode == "had":
        right = matmul_hadU(hatWr) * SU
        hatW = (matmul_hadU(right.T) * SV).T
    else:
        raise NotImplementedError
    del SU, SV

    if args.rescale_WH:
        hatW = hatW / saved['scaleWH'][None, :].to(device)
    return hatW
|
|
|
|
|
|
|
|
def dtype_from_str(str):
    """Map a dtype name such as ``'torch.int32'`` to the torch dtype object.

    Only ``torch.int64``, ``torch.int32``, ``torch.int16`` and
    ``torch.uint8`` are recognised.

    Args:
        str: the dtype name, including the ``torch.`` prefix. The parameter
            name shadows the builtin and is kept for caller compatibility.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        KeyError: if the name is not one of the supported dtypes.
    """
    supported = ('int64', 'int32', 'int16', 'uint8')
    lookup = {f'torch.{name}': getattr(torch, name) for name in supported}
    return lookup[str]
|
|
|