"""Fit a linear softmax readout (W, b) on stacked features with streaming CE.

The readout is initialised from a ridge-regression solve over streamed
sufficient statistics (or from a checkpoint via ``--resume``) and then
refined with AdamW on cross-entropy, evaluating perplexity periodically.
"""
import argparse
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F

from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log


def iter_chunks(tokens, eos, args, max_tokens, shuffle=False):
    """Yield consecutive slices of ``tokens`` holding ``args.chunk_docs`` documents.

    Only the first ``max_tokens`` tokens are considered.  A document starts at
    position 0 and at every position that immediately follows an ``eos``
    token.  The last chunk may contain fewer documents and runs to the end of
    the truncated array.

    Args:
        tokens: 1-D NumPy array of token ids.
        eos: token id that terminates a document.
        args: namespace providing ``chunk_docs``.
        max_tokens: number of leading tokens to use.
        shuffle: visit the chunks in random order (uses the global ``random``
            state) instead of file order.

    Yields:
        NumPy slices (views) of ``tokens``.
    """
    xnp = tokens[:max_tokens]
    # True at index 0 and wherever the previous token was eos.
    starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
    first_doc_ids = list(range(0, len(starts), args.chunk_docs))
    if shuffle:
        random.shuffle(first_doc_ids)
    for i in first_doc_ids:
        lo = starts[i]
        nxt = i + args.chunk_docs
        hi = starts[nxt] if nxt < len(starts) else len(xnp)
        yield xnp[lo:hi]


def chunk_features(xnp, E, Ps, Bs, eos, args, device):
    """Compute readout features and next-token targets for one chunk.

    Args:
        xnp: NumPy array of token ids for the chunk.
        E, Ps, Bs: embedding and per-layer parameters, passed through to
            ``apply_stack``.
        eos: end-of-document token id.
        args: namespace; ``extra_context`` is read here, the rest is consumed
            by ``apply_stack``.
        device: torch device for the computation.

    Returns:
        ``(X, y)``: float32 feature rows and int64 next-token targets for the
        positions kept by the mask.  A position is kept when it is not the
        last one in the chunk, its token is not ``eos``, and the following
        position has the same ``seg`` value (presumably the document index
        returned by ``doc_index`` -- confirm against that helper).
    """
    x = torch.tensor(xnp.astype(np.int64), device=device)
    seg, within = doc_index(x, eos)
    H, phis = apply_stack(x, E, Ps, Bs, within, args)
    Phi = features(H, within, phis, args.extra_context)

    # Target at position t is the token at t + 1; the final slot is a
    # placeholder that the mask below always removes.
    y = torch.empty(len(x), device=device, dtype=torch.long)
    y[:-1] = x[1:]
    y[-1] = eos

    m = torch.ones(len(x), device=device, dtype=torch.bool)
    m[-1] = False
    m[:-1] &= seg[1:].eq(seg[:-1])
    m &= x.ne(eos)
    return Phi[m].float(), y[m]


def eval_ppl(tokens, E, Ps, Bs, W, b, eos, args, device):
    """Return the perplexity of the linear readout ``X @ W + b`` on ``tokens``.

    Uses the first ``args.eval_tokens`` tokens, visits chunks in file order
    and scores them in mini-batches of ``args.batch`` rows under
    ``torch.no_grad()``.  Returns ``exp(0) == 1.0`` when there is nothing to
    score.
    """
    nll = 0.0
    n = 0
    with torch.no_grad():
        for xnp in iter_chunks(tokens, eos, args, args.eval_tokens, shuffle=False):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            for i in range(0, len(y), args.batch):
                yb = y[i:i + args.batch]
                logits = X[i:i + args.batch] @ W + b
                nll += float(F.cross_entropy(logits, yb, reduction="sum"))
                n += len(yb)
    return math.exp(nll / max(1, n))


def _build_parser():
    """Build the command-line parser (flags and defaults unchanged)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", default="/workspace/glm")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--valid_bin", default="/workspace/glm/glm_valid.bin")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--fit_tokens", type=int, default=3_600_000)
    ap.add_argument("--eval_tokens", type=int, default=159_631)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--ridge_lam", type=float, default=10.0)
    ap.add_argument("--init_scale", type=float, default=0.05)
    ap.add_argument("--steps", type=int, default=800)
    ap.add_argument("--batch", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--wd", type=float, default=1e-4)
    ap.add_argument("--eval_every", type=int, default=100)
    ap.add_argument("--save", default="")
    ap.add_argument("--resume", default="")
    return ap


def _ridge_init(train, E, Ps, Bs, eos, V, args, device):
    """Return the scaled ridge-regression initialisation of the readout ``W``.

    Builds the ridge init from streaming statistics only: ``A = X^T X`` and
    ``G = X^T onehot(y)`` are accumulated chunk by chunk in float64, so the
    full feature matrix is never materialised.  The regulariser is
    ``args.ridge_lam`` times the mean diagonal of ``A``; the solution is cast
    to float32 and multiplied by ``args.init_scale``.
    """
    A = G = None
    # Pure statistics pass: no autograd graph is needed, as in eval_ppl.
    with torch.no_grad():
        for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=False):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            if A is None:
                D = X.shape[1]
                A = torch.zeros((D, D), device=device, dtype=torch.float64)
                G = torch.zeros((D, V), device=device, dtype=torch.float64)
            Xd = X.double()
            A += Xd.T @ Xd
            # Column y[i] of G receives feature row i, i.e. G += X^T onehot(y).
            G.index_add_(1, y, Xd.T)
    diag = torch.trace(A) / A.shape[0]
    eye = torch.eye(A.shape[0], device=device, dtype=torch.float64)
    W0 = torch.linalg.solve(A + args.ridge_lam * diag * eye, G).float()
    return (args.init_scale * W0).detach().clone()


def _save_checkpoint(path, W, b, ppl, args):
    """Write a CPU copy of ``W``/``b`` plus their perplexity and the run args."""
    torch.save(
        {"W": W.detach().cpu(), "b": b.detach().cpu(), "ppl": ppl, "args": vars(args)},
        path,
    )


def main():
    """Train the readout and optionally checkpoint the best weights.

    Pipeline: load the SentencePiece model and token files, build the PPMI
    embedding and the per-layer maps, initialise ``(W, b)`` from
    ``--resume`` or from the ridge solve, then run ``--steps`` AdamW updates
    on random mini-batches drawn from shuffled chunks, evaluating every
    ``--eval_every`` steps (a value <= 0 disables periodic evaluation).

    The file at ``--save`` always holds weights whose recorded ``ppl`` was
    measured on those same weights: the best evaluated weights when any
    evaluation improved on the initial perplexity, otherwise the final ones.
    """
    args = _build_parser().parse_args()

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    valid = np.fromfile(args.valid_bin, dtype=np.uint16)
    log("STREAM_CE device", device, "V", V, "train", len(train), "valid", len(valid))

    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)

    if args.resume and os.path.exists(args.resume):
        # Checkpoint supplies both parameters, so the ridge pass is skipped
        # instead of being computed and thrown away.
        ck = torch.load(args.resume, map_location=device)
        W = ck["W"].to(device)
        b = ck["b"].to(device)
        log("resumed", args.resume, "ppl", ck.get("ppl"))
    else:
        W = _ridge_init(train, E, Ps, Bs, eos, V, args, device)
        b = torch.zeros(V, device=device)

    W = W.requires_grad_(True)
    b = b.requires_grad_(True)
    opt = torch.optim.AdamW([W, b], lr=args.lr, weight_decay=args.wd)

    log("init_eval_start D", W.shape[0])
    best = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
    log(f"STREAM_CE init_ppl={best:.2f}")

    step = 0
    ppl = best        # perplexity of the weights as of `eval_step`
    eval_step = 0     # step at which `ppl` was measured
    saved = False     # whether --save already holds a checkpoint from this run
    while step < args.steps:
        progressed = False
        for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=True):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            if len(y) == 0:
                continue
            idx = torch.randint(0, len(y), (min(args.batch, len(y)),), device=device)
            loss = F.cross_entropy(X[idx] @ W + b, y[idx])
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            step += 1
            progressed = True
            if args.eval_every > 0 and step % args.eval_every == 0:
                ppl = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
                eval_step = step
                if ppl < best:
                    best = ppl
                    if args.save:
                        _save_checkpoint(args.save, W, b, best, args)
                        saved = True
                log(f"step={step} loss={float(loss):.4f} ppl={ppl:.2f} best={best:.2f}")
            if step >= args.steps:
                break
        if not progressed:
            # A whole pass produced no update; another pass cannot either,
            # so stop rather than spin forever.
            log("STREAM_CE no trainable positions in a full pass; stopping at step", step)
            break

    if eval_step != step:
        # The last update was never scored; measure it so the final weights
        # can compete for "best" and carry an accurate label if saved.
        ppl = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
        if ppl < best:
            best = ppl
            if args.save:
                _save_checkpoint(args.save, W, b, best, args)
                saved = True
    log(f"STREAM_CE best_ppl={best:.2f}")

    # Do not clobber a best checkpoint with later, worse weights.  If nothing
    # was written yet, store the final weights with their own perplexity.
    if args.save and not saved:
        _save_checkpoint(args.save, W, b, ppl, args)


if __name__ == "__main__":
    main()