import multiprocessing as mp

import glog
import torch

from .matmul_had import matmul_hadU

def flat_to_sym(V, N):
    # Expand a flat vector of the N*(N+1)/2 lower-triangular entries
    # into a full symmetric N x N matrix.
    A = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    idxs = torch.tril_indices(N, N, device=V.device)
    A[idxs.unbind()] = V
    A[idxs[1, :], idxs[0, :]] = V
    return A

def sym_to_flat(A):
    # Pack the lower triangle of a symmetric matrix into a flat vector;
    # inverse of flat_to_sym.
    N = A.shape[-1]
    idxs = torch.tril_indices(N, N, device=A.device)
    return A[idxs.unbind()]

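# Round-trip sketch for sym_to_flat / flat_to_sym (an illustrative check,
# not part of the library API; the matrix size is arbitrary):
def _demo_sym_roundtrip(N=4):
    X = torch.randn(N, N)
    A = X + X.T  # symmetric test matrix
    assert torch.equal(flat_to_sym(sym_to_flat(A), N), A)
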
def register_H_hook(module, device):
    # Register a forward pre-hook on a linear `module` that accumulates the
    # unnormalized second-moment matrix H = sum_x x x^T of its inputs, the
    # running input sum mu, and the sample count ct.
    n = module.in_features
    H = torch.zeros(n, n, dtype=torch.float64, device=device)
    mu = torch.zeros(n, dtype=torch.float64, device=device)
    ct = 0

    def H_hook(module, x):
        nonlocal H, mu, ct, n
        x = x[0].reshape(-1, n).to(torch.float64)
        mu.add_(x.sum(dim=0))
        H.addmm_(x.T, x)
        ct += len(x)

    hook = module.register_forward_pre_hook(H_hook)

    def done():
        # Remove the hook and return the accumulated statistics on CPU.
        nonlocal H, mu, ct, hook
        hook.remove()
        return H.cpu(), mu.cpu(), ct

    return done

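# Minimal sketch of collecting input statistics with register_H_hook
# (illustrative only; the layer sizes and batch count are arbitrary):
def _demo_H_hook():
    layer = torch.nn.Linear(16, 8)
    done = register_H_hook(layer, 'cpu')
    for _ in range(4):
        layer(torch.randn(32, 16))
    H, mu, ct = done()  # H: 16 x 16 sum of x x^T; mu: per-feature sums; ct == 128
    return H, mu, ct
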
def block_LDL(H, b):
    # Block LDL^T factorization with block size b: returns L (block unit
    # lower-triangular, as an n x n matrix) and the m diagonal b x b blocks
    # D such that H = L @ block_diag(*D) @ L^T.
    n = H.shape[0]
    assert n % b == 0
    m = n // b
    L = torch.linalg.cholesky(H)
    # Pull out the m diagonal b x b blocks of the Cholesky factor.
    DL = torch.diagonal(L.reshape(m, b, m, b), dim1=0, dim2=2).permute(2, 0, 1)
    D = DL @ DL.permute(0, 2, 1)
    L = L.view(n, m, b)
    for i in range(m):
        # Right-divide block column i by its diagonal block, i.e. compute
        # L[:, i, :] @ DL[i]^-1, done via a solve for numerical stability.
        L[:, i, :] = torch.linalg.solve(DL[i, :, :], L[:, i, :], left=False)
    L = L.reshape(n, n)
    return (L, D)

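# Numerical sanity check for block_LDL (an illustrative sketch; n and b are
# arbitrary and H is made symmetric positive-definite by construction):
def _demo_block_LDL(n=8, b=4):
    X = torch.randn(n, n, dtype=torch.float64)
    H = X @ X.T + n * torch.eye(n, dtype=torch.float64)
    L, D = block_LDL(H, b)
    # L @ block_diag(*D) @ L^T should reconstruct H.
    assert torch.allclose(L @ torch.block_diag(*D) @ L.T, H)
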
def wrap_tokenizer(tokenizer, x, ctx_size):
    # Module-level wrapper so tokenization can be dispatched through mp.Pool
    # (locally defined closures cannot be pickled).
    return tokenizer(x,
                     return_tensors='pt',
                     truncation=True,
                     padding=True,
                     max_length=ctx_size)

def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    # Sample `size` token sequences of exactly `ctx_size` tokens from
    # `dataset`, tokenizing in parallel when nproc > 1. Sequences shorter
    # than the context window are discarded and resampled.
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0
    if nproc > 1:
        p = mp.Pool(nproc)
        while saved < size:
            seqs = [(tokenizer, dataset[torch.randint(len(dataset), (size,))]['text'], ctx_size)
                    for _ in range(nproc)]
            tokens = p.starmap(wrap_tokenizer, seqs)
            for i in range(len(tokens)):
                lens = tokens[i].attention_mask.sum(dim=-1)
                # Keep only sequences that fill the entire context window.
                good = torch.where(lens == ctx_size)[0]
                if len(good) > 0:
                    if saved + len(good) > size:
                        good = good[:size - saved]
                    devset[saved:saved + len(good)] = tokens[i].input_ids[good]
                    saved += len(good)
            glog.info(f'sampled {saved}/{size} sequences')
        p.close()
    else:
        while saved < size:
            tokens = tokenizer(dataset[torch.randint(len(dataset), (size,))]['text'],
                               return_tensors='pt',
                               truncation=True,
                               padding=True,
                               max_length=ctx_size)
            lens = tokens.attention_mask.sum(dim=-1)
            good = torch.where(lens == ctx_size)[0]
            if len(good) > 0:
                if saved + len(good) > size:
                    good = good[:size - saved]
                devset[saved:saved + len(good)] = tokens.input_ids[good]
                saved += len(good)
    return devset

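# Usage sketch for sample_devset (illustrative; assumes the Hugging Face
# `datasets` and `transformers` packages, and the dataset and tokenizer
# names below are placeholders, not anything this module prescribes):
#
#   from datasets import load_dataset
#   from transformers import AutoTokenizer
#   ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
#   tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
#   devset = sample_devset(ds, tok, size=128, ctx_size=2048, nproc=4)
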
def load_quip(save_name, cb, args, device):
    glog.info(f'loading cached compressed layer from path "{save_name}"')
    dict_loaded = torch.load(save_name, map_location=torch.device('cuda', device))
    SU = dict_loaded['SU'].to(device)
    SV = dict_loaded['SV'].to(device)
    Wscale = dict_loaded['Wscale'].to(device)
    Qidxs = dict_loaded['Qidxs'].to(device)
    n, m = len(SU), len(SV)
    # Decode the quantized indices back into weight values via the codebook.
    hatWr = cb.to(device).by_idxs(Qidxs, packed=(cb.packsz != 1)).view(m, n)
    hatWr = hatWr * Wscale
    del Wscale
    if args.lora_rank > 0:
        # Add back the low-rank (LoRA-style) correction.
        A = dict_loaded['A'].to(device)
        B = dict_loaded['B'].to(device)
        hatWr = hatWr + A @ B
        del A, B
    # Undo the incoherence transform to recover the dense weight matrix.
    if args.incoh_mode == 'had':
        hatW = (matmul_hadU((matmul_hadU(hatWr) * SU).T) * SV).T
    elif args.incoh_mode == 'kron':
        hatW = SV.T @ hatWr @ SU
    else:
        raise NotImplementedError
    del SU, SV
    if args.rescale_WH:
        hatW = hatW / dict_loaded['scaleWH'][None, :].to(device)
    return hatW

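# Sketch of how load_quip might be called (illustrative; `codebook` stands
# for a codebook object exposing .to(), .by_idxs(), and .packsz, the
# namespace fields mirror the attributes read above, and the file name is
# a placeholder):
#
#   import argparse
#   args = argparse.Namespace(lora_rank=0, incoh_mode='had', rescale_WH=False)
#   hatW = load_quip('layer0_qkv.pt', codebook, args, device=0)
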
def dtype_from_str(s):
    # Map a dtype's string representation (e.g. str(torch.int32)) back to
    # the torch dtype object.
    dtype_map = {
        'torch.int64': torch.int64,
        'torch.int32': torch.int32,
        'torch.int16': torch.int16,
        'torch.uint8': torch.uint8,
    }
    return dtype_map[s]