| import argparse, math, os, random |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log |
|
|
|
|
def iter_chunks(tokens, eos, args, max_tokens, shuffle=False):
    """Yield contiguous token slices that each hold ``args.chunk_docs`` documents.

    Only the first ``max_tokens`` entries of ``tokens`` are considered. A
    document starts at index 0 and right after every ``eos`` token, so every
    slice begins on a document boundary. The last slice runs to the end of the
    truncated array and may contain fewer documents. With ``shuffle=True`` the
    order of the slices is permuted using the global ``random`` module.
    """
    window = tokens[:max_tokens]
    group = args.chunk_docs
    # Index 0 always opens a document; any later token opens one when the
    # token before it is eos.
    opens_doc = np.concatenate(([True], window[:-1] == eos))
    doc_starts = np.nonzero(opens_doc)[0]
    n_docs = len(doc_starts)

    order = [*range(0, n_docs, group)]
    if shuffle:
        random.shuffle(order)

    for first in order:
        following = first + group
        stop = doc_starts[following] if following < n_docs else len(window)
        yield window[doc_starts[first]:stop]
|
|
|
|
def chunk_features(xnp, E, Ps, Bs, eos, args, device):
    """Compute (features, next-token targets) for the scoreable positions of one chunk.

    ``xnp`` is a 1-D token array (one slice from ``iter_chunks``). It is run
    through ``doc_index`` / ``apply_stack`` / ``features``. The target at
    position t is the token at t+1. A position is kept only if its successor
    has the same ``seg`` id as itself and the position itself is not ``eos``.
    The final position is never kept.

    Returns:
        (Phi[keep] cast to float32, targets[keep]).
    """
    # astype always copies, so the tensor never aliases the caller's array.
    x = torch.as_tensor(xnp.astype(np.int64), device=device)
    seg, within = doc_index(x, eos)
    H, phis = apply_stack(x, E, Ps, Bs, within, args)
    Phi = features(H, within, phis, args.extra_context)

    # Shift left by one; the last slot has no successor, so it gets an eos
    # placeholder (it is excluded by the mask below).
    targets = torch.empty_like(x)
    targets[:-1] = x[1:]
    targets[-1] = eos

    keep = torch.zeros_like(x, dtype=torch.bool)
    keep[:-1] = seg[1:] == seg[:-1]
    keep &= x != eos
    return Phi[keep].float(), targets[keep]
|
|
|
|
def eval_ppl(tokens, E, Ps, Bs, W, b, eos, args, device):
    """Return the perplexity of the linear readout ``X @ W + b`` on ``tokens``.

    Streams the first ``args.eval_tokens`` tokens in document-aligned chunks
    (no shuffling). It builds features with ``chunk_features`` and sums the
    token-level cross-entropy in minibatches of ``args.batch`` positions.
    Gradients are disabled.

    Returns:
        ``exp(total_nll / n_positions)``. The result is ``math.inf`` when no
        position could be scored. Returning 1.0 here would report a perfect
        model. It is also ``math.inf`` when the mean NLL is too large for
        ``math.exp``, which would otherwise raise ``OverflowError``.
    """
    # Accumulate on-device in float64. Each per-batch float32 sum is widened
    # exactly, so this matches summing Python floats, but it needs only one
    # host synchronisation at the end instead of one per batch.
    nll = torch.zeros((), dtype=torch.float64, device=device)
    n = 0
    with torch.no_grad():
        for xnp in iter_chunks(tokens, eos, args, args.eval_tokens, shuffle=False):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            for lo in range(0, len(y), args.batch):
                yb = y[lo:lo + args.batch]
                logits = X[lo:lo + args.batch] @ W + b
                nll += F.cross_entropy(logits, yb, reduction="sum").double()
                n += len(yb)
    if n == 0:
        return math.inf
    mean_nll = float(nll) / n
    try:
        return math.exp(mean_nll)
    except OverflowError:
        return math.inf
|
|
|
|
def _parse_args():
    """Define the command-line interface and return the parsed namespace."""
    ap = argparse.ArgumentParser()
    # NOTE(review): --data and --vocab are not read in this file (V comes from
    # the SentencePiece model); the other options are presumably consumed by
    # the torch_predictive_attn helpers via `args` -- confirm before pruning.
    ap.add_argument("--data", default="/workspace/glm")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--valid_bin", default="/workspace/glm/glm_valid.bin")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--fit_tokens", type=int, default=3_600_000)
    ap.add_argument("--eval_tokens", type=int, default=159_631)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--ridge_lam", type=float, default=10.0)
    ap.add_argument("--init_scale", type=float, default=0.05)
    ap.add_argument("--steps", type=int, default=800)
    ap.add_argument("--batch", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--wd", type=float, default=1e-4)
    ap.add_argument("--eval_every", type=int, default=100)
    ap.add_argument("--save", default="")
    ap.add_argument("--resume", default="")
    return ap.parse_args()


def _ridge_init(train, E, Ps, Bs, eos, V, args, device):
    """Solve a ridge regression from chunk features onto one-hot next tokens.

    Streams the first ``args.fit_tokens`` training tokens. It accumulates
    ``A = X^T X`` (D x D) and ``G = X^T onehot(y)`` (D x V) in float64, then
    solves ``(A + ridge_lam * mean(diag A) * I) W = G``.

    Returns:
        The float32 solution of shape (D, V).
    """
    A = G = None
    for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=False):
        X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
        if A is None:
            D = X.shape[1]
            A = torch.zeros((D, D), device=device, dtype=torch.float64)
            G = torch.zeros((D, V), device=device, dtype=torch.float64)
        Xd = X.double()
        A += Xd.T @ Xd
        # Column y[j] of G receives row j of X, i.e. G += X^T onehot(y).
        G.index_add_(1, y, Xd.T)
    # The penalty is scaled by the mean diagonal of A, so ridge_lam acts
    # relative to the average feature energy.
    diag = torch.trace(A) / A.shape[0]
    eye = torch.eye(A.shape[0], device=device, dtype=torch.float64)
    return torch.linalg.solve(A + args.ridge_lam * diag * eye, G).float()


def main():
    """Train a linear next-token readout ``X @ W + b`` on streamed stack features.

    Loads the SentencePiece model and the uint16 token files. It builds the
    PPMI embedding and ``--layers`` maps, then initialises ``W`` from a scaled
    ridge solution, or restores ``W``/``b`` from ``--resume`` if that file
    exists. ``W`` and ``b`` are then refined with AdamW on random minibatches
    of chunk features, and validation perplexity is evaluated every
    ``--eval_every`` steps and once more after the last step.

    When ``--save`` is set, that checkpoint always holds the best-scoring
    weights seen so far together with their perplexity.
    """
    args = _parse_args()

    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    valid = np.fromfile(args.valid_bin, dtype=np.uint16)
    log("STREAM_CE device", device, "V", V, "train", len(train), "valid", len(valid))
    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        # learn_map is given the maps built so far (presumably stacking on them).
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)

    if args.resume and os.path.exists(args.resume):
        ck = torch.load(args.resume, map_location=device)
        W = ck["W"].to(device)
        b = ck["b"].to(device)
        log("resumed", args.resume, "ppl", ck.get("ppl"))
    else:
        # The ridge fit is a full pass over fit_tokens, so only pay for it when
        # there is no checkpoint to resume from.
        W0 = _ridge_init(train, E, Ps, Bs, eos, V, args, device)
        W = (args.init_scale * W0).detach().clone()
        b = torch.zeros(V, device=device)
    W.requires_grad_(True)
    b.requires_grad_(True)
    opt = torch.optim.AdamW([W, b], lr=args.lr, weight_decay=args.wd)

    log("init_eval_start D", W.shape[0])
    best = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
    log(f"STREAM_CE init_ppl={best:.2f}")

    def save_checkpoint(ppl):
        # Called only when the current W/b are the best so far, so the stored
        # weights always match the stored perplexity.
        if args.save:
            torch.save({"W": W.detach().cpu(), "b": b.detach().cpu(), "ppl": ppl, "args": vars(args)}, args.save)

    def evaluate():
        # Score the current weights; checkpoint them if they beat `best`.
        nonlocal best
        ppl = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
        if ppl < best:
            best = ppl
            save_checkpoint(best)
        return ppl

    save_checkpoint(best)

    step = 0
    last_eval = 0
    loss = None
    while step < args.steps:
        progressed = False
        for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=True):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            if len(y) == 0:
                continue
            # One optimizer step per chunk, on positions sampled with replacement.
            idx = torch.randint(0, len(y), (min(args.batch, len(y)),), device=device)
            loss = F.cross_entropy(X[idx] @ W + b, y[idx])
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            step += 1
            progressed = True
            if args.eval_every > 0 and step % args.eval_every == 0:
                ppl = evaluate()
                last_eval = step
                log(f"step={step} loss={float(loss):.4f} ppl={ppl:.2f} best={best:.2f}")
            if step >= args.steps:
                break
        if not progressed:
            # Every chunk was empty after masking; another pass would spin forever.
            raise RuntimeError("STREAM_CE: no trainable positions in the first fit_tokens training tokens")

    if step > last_eval:
        # Score the final weights when steps is not a multiple of eval_every.
        ppl = evaluate()
        log(f"step={step} loss={float(loss):.4f} ppl={ppl:.2f} best={best:.2f}")
    log(f"STREAM_CE best_ppl={best:.2f}")
|
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, never as a
    # side effect of importing this module.
    main()
|
|