import torch
from .matmul_had import matmul_hadU
import glog
import multiprocessing as mp

def flat_to_sym(V, N):
    # Expand a flattened lower-triangular vector V into a full symmetric N x N matrix.
    A = torch.zeros(N, N, dtype=V.dtype, device=V.device)
    idxs = torch.tril_indices(N, N, device=V.device)
    A[idxs.unbind()] = V
    A[idxs[1, :], idxs[0, :]] = V
    return A


def sym_to_flat(A):
    # Flatten a symmetric matrix A into its lower-triangular (including diagonal) entries.
    N = A.shape[-1]
    idxs = torch.tril_indices(N, N, device=A.device)
    return A[idxs.unbind()]
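
# Round-trip sketch (a minimal check; the size 8 is arbitrary):
#
#   A = torch.randn(8, 8)
#   A = (A + A.T) / 2                                  # symmetrize
#   assert torch.equal(flat_to_sym(sym_to_flat(A), 8), A)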


def register_H_hook(module, device):
    # Register a forward pre-hook on a linear module that accumulates input statistics
    # in float64: H += x^T x, mu += per-feature sum of x, ct counts rows seen.
    # Returns a closure that removes the hook and hands back (H, mu, ct) on the CPU.
    n = module.in_features
    H = torch.zeros(n, n, dtype=torch.float64, device=device)
    mu = torch.zeros(n, dtype=torch.float64, device=device)
    ct = 0

    def H_hook(module, x):
        nonlocal H, mu, ct, n
        x = x[0].reshape(-1, n).to(torch.float64)
        mu.add_(x.sum(dim=0))
        H.addmm_(x.T, x)
        ct += len(x)

    hook = module.register_forward_pre_hook(H_hook)

    def done():
        nonlocal H, mu, ct, hook
        hook.remove()
        return H.cpu(), mu.cpu(), ct

    return done
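
# Usage sketch (hypothetical names; `layer` is any nn.Linear inside `model`, and
# `calib_batches` yields inputs to the model):
#
#   done = register_H_hook(layer, device='cuda:0')
#   for batch in calib_batches:
#       model(batch)          # forward passes through `layer` accumulate H, mu, ct
#   H, mu, ct = done()        # removes the hook and returns the statistics on CPU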


def block_LDL(H, b):
    # Block LDL^T decomposition with block size b: factor H = L @ block_diag(D) @ L^T,
    # where L is block unit lower-triangular and D stacks the m = n // b diagonal blocks.
    n = H.shape[0]
    assert n % b == 0
    m = n // b
    L = torch.linalg.cholesky(H)
    DL = torch.diagonal(L.reshape(m, b, m, b), dim1=0, dim2=2).permute(2, 0, 1)
    D = DL @ DL.permute(0, 2, 1)
    L = L.view(n, m, b)
    for i in range(m):
        # Right-divide each block column by its diagonal Cholesky block: L[:, i, :] @ DL[i]^{-1}.
        L[:, i, :] = torch.linalg.solve(DL[i, :, :], L[:, i, :], left=False)
    L = L.reshape(n, n)
    return (L, D)
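
# Sanity-check sketch (arbitrary sizes; H is made positive-definite so Cholesky succeeds):
#
#   X = torch.randn(256, 64, dtype=torch.float64)
#   H = X.T @ X + 1e-2 * torch.eye(64, dtype=torch.float64)
#   L, D = block_LDL(H, 16)
#   assert torch.allclose(L @ torch.block_diag(*D) @ L.T, H)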

def wrap_tokenizer(tokenizer, x, ctx_size):
    # Module-level wrapper so tokenization can be dispatched through multiprocessing starmap.
    return tokenizer(x, return_tensors='pt', truncation=True, padding=True, max_length=ctx_size)

def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    # Draw `size` calibration sequences of exactly `ctx_size` tokens: tokenize random
    # examples (optionally across `nproc` worker processes) and keep only sequences
    # whose attention mask spans the full context length.
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0
    if nproc > 1:
        p = mp.Pool(nproc)
        while saved < size:
            seqs = [(tokenizer, dataset[torch.randint(len(dataset), (size,))]['text'], ctx_size) for _ in range(nproc)]
            tokens = p.starmap(wrap_tokenizer, seqs)
            for i in range(len(tokens)):
                lens = tokens[i].attention_mask.sum(dim=-1)
                good = torch.where(lens == ctx_size)[0]
                if len(good) > 0:
                    if saved + len(good) > size:
                        good = good[:size - saved]
                    devset[saved: saved+len(good)] = tokens[i].input_ids[good]
                    saved += len(good)
                    print(saved)
    else:
        while saved < size:
            tokens = tokenizer(dataset[torch.randint(len(dataset), (size,))]['text'],
                               return_tensors='pt',
                               truncation=True,
                               padding=True,
                               max_length=ctx_size)
            lens = tokens.attention_mask.sum(dim=-1)
            good = torch.where(lens == ctx_size)[0]
            if len(good) > 0:
                if saved + len(good) > size:
                    good = good[:size - saved]
                devset[saved: saved+len(good)] = tokens.input_ids[good]
                saved += len(good)
    return devset
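
# Usage sketch (hypothetical dataset and tokenizer; assumes a Hugging Face dataset
# with a 'text' column):
#
#   from datasets import load_dataset
#   from transformers import AutoTokenizer
#   ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
#   tok = AutoTokenizer.from_pretrained('gpt2')
#   tok.pad_token = tok.eos_token              # padding=True requires a pad token
#   devset = sample_devset(ds, tok, size=64, ctx_size=1024)   # (64, 1024) int64 tensor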


def load_quip(save_name, cb, args, device):
    # Rebuild a dequantized weight matrix from a cached compressed layer: decode the
    # codebook indices, rescale, add the optional low-rank correction, and undo the
    # incoherence transform (Hadamard or Kronecker).
    glog.info(f"loading cached compressed layer from path \"{save_name}\"")
    dict_loaded = torch.load(save_name, map_location=torch.device('cuda', device))
    SU = dict_loaded['SU'].to(device)
    SV = dict_loaded['SV'].to(device)
    Wscale = dict_loaded['Wscale'].to(device)
    Qidxs = dict_loaded['Qidxs'].to(device)
    n, m = len(SU), len(SV)
    hatWr = cb.to(device).by_idxs(Qidxs, packed=(cb.packsz != 1)).view(m, n)
    hatWr = hatWr * Wscale
    del Wscale
    if args.lora_rank > 0:
        A = dict_loaded['A'].to(device)
        B = dict_loaded['B'].to(device)
        hatWr = hatWr + A @ B
        del A, B
    if args.incoh_mode == "had":
        hatW = (matmul_hadU((matmul_hadU(hatWr) * SU).T) * SV).T
    elif args.incoh_mode == "kron":
        hatW = SV.T @ hatWr @ SU
    else:
        raise NotImplementedError
    del SU, SV
    if args.rescale_WH:
        hatW = hatW / dict_loaded['scaleWH'][None, :].to(device)
    return hatW


def dtype_from_str(s):
    # Map a dtype's string representation (e.g. 'torch.int32') back to the torch dtype.
    dtype_map = {
        'torch.int64': torch.int64,
        'torch.int32': torch.int32,
        'torch.int16': torch.int16,
        'torch.uint8': torch.uint8,
    }
    return dtype_map[s]