# testoneshot/scripts/torch_ce_stream_readout.py
# Uploaded by Asilarknes in commit a216fa7 ("upload oneshot glm artifacts", verified).
import argparse, math, os, random
import numpy as np
import torch
import torch.nn.functional as F
from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log
def iter_chunks(tokens, eos, args, max_tokens, shuffle=False):
    """Yield contiguous groups of ``args.chunk_docs`` documents from a token stream.

    Only the first ``max_tokens`` tokens are used. A document starts at
    position 0 and immediately after every ``eos`` token. Each yielded chunk
    therefore begins on a document start and ends either just before the next
    group's first document or at the truncation point. The last chunk may hold
    fewer than ``args.chunk_docs`` documents.

    Args:
        tokens: 1-D numpy array of token ids.
        eos: end-of-document token id.
        args: namespace providing ``chunk_docs`` (documents per chunk, >= 1).
        max_tokens: number of leading tokens of ``tokens`` to consider.
        shuffle: if True, chunks are yielded in an order permuted with
            ``random.shuffle`` (Python's global RNG).

    Yields:
        Numpy slices of ``tokens``. Nothing is yielded when the truncated
        stream is empty.

    Raises:
        ValueError: if ``args.chunk_docs < 1``.
    """
    if args.chunk_docs < 1:
        raise ValueError(f"chunk_docs must be >= 1, got {args.chunk_docs}")
    xnp = tokens[:max_tokens]
    if len(xnp) == 0:
        # Without this guard np.r_[True, ...] reports a document at index 0
        # and one empty chunk is yielded, which downstream code cannot handle.
        return
    # Document starts: index 0 plus every position right after an eos token.
    starts = np.flatnonzero(np.r_[True, xnp[:-1] == eos])
    ids = list(range(0, len(starts), args.chunk_docs))
    if shuffle:
        random.shuffle(ids)
    for i in ids:
        lo = starts[i]
        nxt = i + args.chunk_docs
        # A chunk ends at the first document of the next group, or at the stream end.
        hi = starts[nxt] if nxt < len(starts) else len(xnp)
        yield xnp[lo:hi]
def chunk_features(xnp, E, Ps, Bs, eos, args, device):
    """Convert one token chunk into readout features and next-token targets.

    The chunk is run through ``apply_stack`` and ``features``. A position ``i``
    is kept only when all of the following hold:

    * it is not the last position;
    * token ``i + 1`` has the same ``seg`` value (as returned by ``doc_index``);
    * token ``i`` itself is not ``eos``.

    Returns:
        ``(Phi, y)``: float32 feature rows of the kept positions and their
        int64 next-token ids, row-aligned.
    """
    tokens = torch.tensor(xnp.astype(np.int64), device=device)
    seg, within = doc_index(tokens, eos)
    hidden, phis = apply_stack(tokens, E, Ps, Bs, within, args)
    feats = features(hidden, within, phis, args.extra_context)

    count = len(tokens)
    # Shift-by-one targets; the final slot has no successor and is filled with
    # eos only so the tensor is fully defined (that slot is masked out below).
    targets = torch.empty(count, device=device, dtype=torch.long)
    targets[-1] = eos
    targets[:-1] = tokens[1:]

    # The last position starts False and stays False.
    keep = torch.zeros(count, device=device, dtype=torch.bool)
    keep[:-1] = seg[1:] == seg[:-1]
    keep &= tokens != eos
    return feats[keep].float(), targets[keep]
def eval_ppl(tokens, E, Ps, Bs, W, b, eos, args, device):
    """Return the perplexity of the linear readout ``X @ W + b`` on ``tokens``.

    Streams the first ``args.eval_tokens`` tokens in document-aligned chunks
    (no shuffling). It scores the targets kept by ``chunk_features`` in
    mini-batches of ``args.batch`` rows and returns
    ``exp(total_nll / n_targets)``.

    Returns:
        The perplexity as a Python float, with two special cases:

        * 1.0 when no target survives masking. NOTE(review): this is a
          "perfect" score for an empty set; kept for backward compatibility.
        * ``math.inf`` when the mean NLL is too large for ``math.exp``, e.g.
          after the readout diverges, instead of raising ``OverflowError``.
    """
    # float64 on-device accumulator: same arithmetic as summing Python floats,
    # but without a host sync on every batch.
    nll = torch.zeros((), device=device, dtype=torch.float64)
    n = 0
    with torch.no_grad():
        for xnp in iter_chunks(tokens, eos, args, args.eval_tokens, shuffle=False):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            for i in range(0, len(y), args.batch):
                yb = y[i:i + args.batch]
                logits = X[i:i + args.batch] @ W + b
                nll += F.cross_entropy(logits, yb, reduction="sum").double()
                n += len(yb)
    mean_nll = float(nll) / max(1, n)
    try:
        return math.exp(mean_nll)
    except OverflowError:
        return math.inf
def _ridge_init(train, E, Ps, Bs, eos, V, args, device):
    """Return the closed-form ridge readout ``W0`` as float32 of shape ``(D, V)``.

    Streams the first ``args.fit_tokens`` training tokens and accumulates, in
    float64:

    * ``A = X^T X``;
    * ``G = X^T Y``, where ``Y`` is the one-hot next-token target matrix
      (built implicitly with ``index_add_``).

    It then solves ``(A + ridge_lam * mean(diag(A)) * I) W0 = G``.

    Raises:
        ValueError: if the training stream yields no chunks.
    """
    A = G = None
    for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=False):
        X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
        if A is None:
            D = X.shape[1]
            A = torch.zeros((D, D), device=device, dtype=torch.float64)
            G = torch.zeros((D, V), device=device, dtype=torch.float64)
        Xd = X.double()
        A += Xd.T @ Xd
        # Column y[j] of G gains feature row j, i.e. G += X^T @ one_hot(y, V).
        G.index_add_(1, y, Xd.T)
    if A is None:
        raise ValueError("ridge init found no training chunks; check --train_bin / --fit_tokens")
    # The penalty is --ridge_lam times the mean diagonal entry of A.
    diag = torch.trace(A) / A.shape[0]
    eye = torch.eye(A.shape[0], device=device, dtype=torch.float64)
    return torch.linalg.solve(A + args.ridge_lam * diag * eye, G).float()


def _save_checkpoint(path, W, b, ppl, args):
    """Save readout weights ``W``/``b`` (on CPU), their validation ``ppl`` and the run args."""
    torch.save({"W": W.detach().cpu(), "b": b.detach().cpu(), "ppl": ppl, "args": vars(args)}, path)


def main():
    """Train a linear cross-entropy readout on streamed stack features.

    Steps:

    1. Build PPMI embeddings and ``--layers`` learned maps from the training
       tokens.
    2. Initialise the readout ``(W, b)``:
       * from ``--resume`` when that file exists, or
       * from a ridge fit scaled by ``--init_scale``.
    3. Run ``--steps`` AdamW steps, drawing one mini-batch per shuffled
       training chunk.
    4. Every ``--eval_every`` steps, measure validation perplexity and keep
       the best-perplexity weights.

    When ``--save`` is given, the best weights are written on every
    improvement and once more at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", default="/workspace/glm")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--valid_bin", default="/workspace/glm/glm_valid.bin")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--fit_tokens", type=int, default=3_600_000)
    ap.add_argument("--eval_tokens", type=int, default=159_631)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--ridge_lam", type=float, default=10.0)
    ap.add_argument("--init_scale", type=float, default=0.05)
    ap.add_argument("--steps", type=int, default=800)
    ap.add_argument("--batch", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--wd", type=float, default=1e-4)
    ap.add_argument("--eval_every", type=int, default=100)
    ap.add_argument("--save", default="")
    ap.add_argument("--resume", default="")
    args = ap.parse_args()
    # These are used as range steps / a modulus below; reject values that
    # would crash or silently produce an empty stream.
    for name in ("chunk_docs", "batch", "eval_every"):
        if getattr(args, name) < 1:
            ap.error(f"--{name} must be >= 1")

    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    valid = np.fromfile(args.valid_bin, dtype=np.uint16)
    log("STREAM_CE device", device, "V", V, "train", len(train), "valid", len(valid))

    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)

    if args.resume and os.path.exists(args.resume):
        # Resumed weights replace the ridge init, so the expensive ridge fit is skipped.
        ck = torch.load(args.resume, map_location=device)
        W = ck["W"].to(device)
        b = ck["b"].to(device)
        log("resumed", args.resume, "ppl", ck.get("ppl"))
    else:
        W = (args.init_scale * _ridge_init(train, E, Ps, Bs, eos, V, args, device)).detach().clone()
        b = torch.zeros(V, device=device)
    W.requires_grad_(True)
    b.requires_grad_(True)
    opt = torch.optim.AdamW([W, b], lr=args.lr, weight_decay=args.wd)

    log("init_eval_start D", W.shape[0])
    best = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
    # CPU snapshot of the best weights. copy=True matters: on CPU, .cpu()
    # would alias W and be mutated by later optimizer steps.
    best_W = W.detach().to("cpu", copy=True)
    best_b = b.detach().to("cpu", copy=True)
    log(f"STREAM_CE init_ppl={best:.2f}")

    step = 0
    while step < args.steps:
        progressed = False
        for xnp in iter_chunks(train, eos, args, args.fit_tokens, shuffle=True):
            X, y = chunk_features(xnp, E, Ps, Bs, eos, args, device)
            if len(y) == 0:
                continue
            progressed = True
            idx = torch.randint(0, len(y), (min(args.batch, len(y)),), device=device)
            loss = F.cross_entropy(X[idx] @ W + b, y[idx])
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            step += 1
            if step % args.eval_every == 0:
                ppl = eval_ppl(valid, E, Ps, Bs, W, b, eos, args, device)
                if ppl < best:
                    best = ppl
                    best_W = W.detach().to("cpu", copy=True)
                    best_b = b.detach().to("cpu", copy=True)
                    if args.save:
                        _save_checkpoint(args.save, best_W, best_b, best, args)
                log(f"step={step} loss={float(loss):.4f} ppl={ppl:.2f} best={best:.2f}")
            if step >= args.steps:
                break
        if not progressed:
            # Without this, a pass with no usable chunk would repeat forever.
            raise RuntimeError("training stream produced no targets; check --train_bin / --fit_tokens")

    log(f"STREAM_CE best_ppl={best:.2f}")
    if args.save:
        # Write the best weights, not the last ones, so "W"/"b" always match "ppl".
        _save_checkpoint(args.save, best_W, best_b, best, args)
if __name__ == "__main__":
    # Run the training script only when executed directly, not when imported.
    main()