File size: 6,772 Bytes
a216fa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse, math, os, random
import numpy as np
import torch
import torch.nn.functional as F
from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log


def iter_chunks(tokens, eos, args, max_tokens, shuffle=False):
    """Yield contiguous slices of ``tokens`` grouped by document boundaries.

    Only the first ``max_tokens`` entries of ``tokens`` are considered. A
    document is taken to start at position 0 and immediately after every
    ``eos`` token. Each yielded slice covers ``args.chunk_docs`` consecutive
    documents; the last slice simply runs to the end of the truncated stream.

    Args:
        tokens: Token array supporting slicing and elementwise ``== eos``.
        eos: Token id that terminates a document.
        args: Object exposing an integer ``chunk_docs`` attribute.
        max_tokens: Number of leading tokens to keep.
        shuffle: When true, chunk order is randomised with ``random.shuffle``
            (module-level RNG); the content of each chunk is unaffected.

    Yields:
        Slices of the truncated token array, one per chunk.
    """
    stream = tokens[:max_tokens]
    stream_len = len(stream)
    # Index 0 always opens a document; every other start follows an eos.
    doc_starts = np.flatnonzero(np.r_[True, stream[:-1] == eos])
    n_docs = len(doc_starts)
    docs_per_chunk = args.chunk_docs

    order = list(range(0, n_docs, docs_per_chunk))
    if shuffle:
        random.shuffle(order)

    for first_doc in order:
        next_first = first_doc + docs_per_chunk
        begin = doc_starts[first_doc]
        if next_first < n_docs:
            end = doc_starts[next_first]
        else:
            # Final chunk: take whatever is left, even a partial document.
            end = stream_len
        yield stream[begin:end]


def chunk_features(xnp, E, Ps, Bs, eos, args, device):
    """Build (feature, next-token target) pairs for one chunk of tokens.

    The chunk is moved to ``device`` as an int64 tensor and passed through the
    project helpers ``doc_index``, ``apply_stack`` and ``features``. Targets
    are the input shifted left by one, with ``eos`` as the final target.

    A position is dropped when any of these holds:
      * it is the last position of the chunk,
      * the next position has a different ``doc_index`` segment value,
      * the token at the position is ``eos``.

    Args:
        xnp: NumPy array of token ids for the chunk.
        E, Ps, Bs: Forwarded unchanged to ``apply_stack``.
        eos: End-of-document token id.
        args: Options object; ``args.extra_context`` is read here and the
            whole object is forwarded to ``apply_stack``.
        device: Torch device for all tensors created here.

    Returns:
        ``(features_float32, targets_int64)`` restricted to kept positions.
    """
    ids = torch.from_numpy(xnp.astype(np.int64)).to(device)

    # NOTE(review): presumably per-token document id and in-document position;
    # confirm against doc_index in torch_predictive_attn.
    segments, within_doc = doc_index(ids, eos)
    hidden, phis = apply_stack(ids, E, Ps, Bs, within_doc, args)
    feats = features(hidden, within_doc, phis, args.extra_context)

    # Next-token targets; the final slot has no successor, so it gets eos.
    targets = torch.empty(len(ids), device=device, dtype=torch.long)
    targets[:-1] = ids[1:]
    targets[-1] = eos

    # Start from "token is not eos", then drop the chunk's last position and
    # every position whose successor lies in a different segment.
    keep = ids.ne(eos)
    keep[-1] = False
    keep[:-1] &= segments[1:].eq(segments[:-1])

    return feats[keep].float(), targets[keep]


def eval_ppl(tokens, E, Ps, Bs, W, b, eos, args, device):
    """Return the perplexity of the linear readout ``X @ W + b`` on ``tokens``.

    The first ``args.eval_tokens`` tokens are walked in unshuffled chunks. For
    each chunk, features and targets come from ``chunk_features``, and the
    summed cross-entropy is accumulated in mini-batches of ``args.batch``
    rows. Gradients are disabled throughout.

    Args:
        tokens: Token array to evaluate on.
        E, Ps, Bs: Forwarded to ``chunk_features``.
        W: Readout weight matrix, right-multiplied onto the features.
        b: Readout bias, added to the logits.
        eos: End-of-document token id.
        args: Options object; ``eval_tokens`` and ``batch`` are read here.
        device: Torch device forwarded to ``chunk_features``.

    Returns:
        ``exp(total_nll / target_count)`` as a Python float; ``1.0`` when no
        targets were scored.
    """
    total_nll = 0.0
    total_targets = 0
    with torch.no_grad():
        chunks = iter_chunks(tokens, eos, args, args.eval_tokens, shuffle=False)
        for chunk in chunks:
            feats, targets = chunk_features(chunk, E, Ps, Bs, eos, args, device)
            for start in range(0, len(targets), args.batch):
                rows = slice(start, start + args.batch)
                target_batch = targets[rows]
                logits = feats[rows] @ W + b
                batch_nll = F.cross_entropy(logits, target_batch, reduction="sum")
                total_nll += float(batch_nll)
                total_targets += len(target_batch)
    # Avoid dividing by zero when nothing was scored (yields exp(0) == 1).
    denom = total_targets if total_targets > 0 else 1
    return math.exp(total_nll / denom)


def _ridge_readout_init(train, E, Ps, Bs, V, eos, args, device):
    """Return the scaled ridge-regression readout used to initialise ``W``.

    Streams over the first ``args.fit_tokens`` training tokens, accumulating
    ``A = X^T X`` and ``G = X^T onehot(y)`` in float64, then solves
    ``(A + ridge_lam * mean_diag(A) * I) W0 = G`` and returns
    ``args.init_scale * W0`` as float32.
    """
    A = G = None
    for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=False):
        X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
        if A is None:
            D = X.shape[1]
            A = torch.zeros((D, D), device=device, dtype=torch.float64)
            G = torch.zeros((D, V), device=device, dtype=torch.float64)
        Xd = X.double()
        A += Xd.T @ Xd
        # Column y[j] of G receives feature row j, i.e. G += X^T @ onehot(y).
        G.index_add_(1, y, Xd.T)
    if A is None:
        raise RuntimeError("no training chunks available for ridge initialisation")
    # Scale the regulariser by the mean diagonal so ridge_lam is unit-free.
    diag = torch.trace(A) / A.shape[0]
    eye = torch.eye(A.shape[0], device=device, dtype=torch.float64)
    W0 = torch.linalg.solve(A + args.ridge_lam * diag * eye, G).float()
    return (args.init_scale * W0).detach().clone()


def _cpu_snapshot(W, b):
    """Return independent CPU copies of the readout parameters."""
    # clone() matters on CPU, where detach().cpu() aliases the live tensor.
    return W.detach().cpu().clone(), b.detach().cpu().clone()


def _write_checkpoint(path, state, ppl, args):
    """Write a ``(W, b)`` snapshot, its perplexity and the run args to ``path``."""
    W_cpu, b_cpu = state
    torch.save({"W": W_cpu, "b": b_cpu, "ppl": ppl, "args": vars(args)}, path)


def main():
    """Train a linear softmax readout on streamed features and report perplexity.

    Steps, as visible in this function:
      1. Parse CLI options and load the SentencePiece model, which supplies
         the eos id and vocabulary size (``--vocab`` is accepted but not read
         anywhere in this function).
      2. Load the uint16 train/valid token files.
      3. Build ``E`` with ``ppmi_embed`` and ``args.layers`` pairs ``(P, B)``
         with ``learn_map`` (project helpers).
      4. Initialise ``W``/``b`` from ``--resume`` if that file exists,
         otherwise from a scaled ridge solution and a zero bias.
      5. Run ``args.steps`` AdamW steps, each on a random mini-batch drawn
         from one shuffled chunk, evaluating every ``args.eval_every`` steps.
      6. Keep the parameters with the lowest validation perplexity and write
         them to ``--save`` (if given), so the stored ``ppl`` always matches
         the stored weights.

    Raises:
        RuntimeError: If a full pass over the training chunks produces no
            usable targets (the loop could otherwise never terminate).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", default="/workspace/glm")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--valid_bin", default="/workspace/glm/glm_valid.bin")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--fit_tokens", type=int, default=3_600_000)
    ap.add_argument("--eval_tokens", type=int, default=159_631)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--ridge_lam", type=float, default=10.0)
    ap.add_argument("--init_scale", type=float, default=0.05)
    ap.add_argument("--steps", type=int, default=800)
    ap.add_argument("--batch", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--wd", type=float, default=1e-4)
    ap.add_argument("--eval_every", type=int, default=100)
    ap.add_argument("--save", default="")
    ap.add_argument("--resume", default="")
    args = ap.parse_args()

    # Imported lazily so the module can be imported without sentencepiece.
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    valid = np.fromfile(args.valid_bin, dtype=np.uint16)
    log("STREAM_CE device", device, "V", V, "train", len(train), "valid", len(valid))
    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        # Each layer is fitted given the maps learned so far.
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)

    if args.resume and os.path.exists(args.resume):
        # A checkpoint replaces the ridge initialisation entirely, so the
        # streaming ridge statistics are not computed in this branch.
        ck = torch.load(args.resume, map_location=device)
        W = ck["W"].to(device)
        b = ck["b"].to(device)
        log("resumed", args.resume, "ppl", ck.get("ppl"))
    else:
        # Build ridge init from streaming stats only.
        W = _ridge_readout_init(train, E, Ps, Bs, V, eos, args, device)
        b = torch.zeros(V, device=device)
    W = W.requires_grad_(True)
    b = b.requires_grad_(True)
    opt = torch.optim.AdamW([W, b], lr=args.lr, weight_decay=args.wd)
    log("init_eval_start D", W.shape[0])
    best = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
    log(f"STREAM_CE init_ppl={best:.2f}")

    # Snapshot of the parameters that achieved `best`; only kept when a
    # checkpoint will actually be written.
    best_state = _cpu_snapshot(W, b) if args.save else None

    step = 0
    while step < args.steps:
        step_at_epoch_start = step
        for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=True):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            if len(y) == 0:
                continue
            # One optimisation step per chunk, sampling rows with replacement.
            idx = torch.randint(0, len(y), (min(args.batch, len(y)),), device=device)
            loss = F.cross_entropy(X[idx] @ W + b, y[idx])
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            step += 1
            if args.eval_every > 0 and step % args.eval_every == 0:
                ppl = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
                if ppl < best:
                    best = ppl
                    if args.save:
                        best_state = _cpu_snapshot(W, b)
                        _write_checkpoint(args.save, best_state, best, args)
                log(f"step={step} loss={float(loss):.4f} ppl={ppl:.2f} best={best:.2f}")
            if step >= args.steps:
                break
        if step == step_at_epoch_start:
            # Without this guard the outer loop would spin forever.
            raise RuntimeError("training data produced no usable targets")
    log(f"STREAM_CE best_ppl={best:.2f}")
    if args.save:
        # Persist the best parameters (not the final ones) so that the
        # recorded perplexity matches the stored weights.
        _write_checkpoint(args.save, best_state, best, args)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()