"""Sample text completions from a learned predictive-attention stack.

The script rebuilds the embedding and per-layer maps from a token stream
(``--train_bin``), loads a linear readout (``W``, ``b``) from ``--ckpt`` and
autoregressively samples a continuation for each prompt.
"""
import argparse
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log

# Prompts used when no --prompt flag is given.
DEFAULT_PROMPTS = (
    "Question: What is gravity?\nAnswer:",
    "Question: Why is the sky blue?\nAnswer:",
    "Question: Explain photosynthesis in simple words.\nAnswer:",
    "Question: If John has 3 apples and buys 2 more, how many apples does he have?\nAnswer:",
    "Question: Write a short friendly story about a robot learning to read.\nAnswer:",
)


def build(args, device):
    """Rebuild the tokenizer, embedding and layer stack used at inference.

    Returns ``(sp, eos, E, Ps, Bs)``: the SentencePiece processor, its EOS id,
    the embedding returned by ``ppmi_embed`` and one ``(P, B)`` pair per layer
    returned by ``learn_map``.  Layers are built sequentially; each call to
    ``learn_map`` is passed the maps learned so far.
    """
    import sentencepiece as spm  # only needed here

    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    # NOTE(review): the token stream is stored as uint16, which assumes the
    # vocabulary has fewer than 65536 pieces.
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    for _ in range(args.layers):
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)
    return sp, eos, E, Ps, Bs


def next_logits(ids, E, Ps, Bs, W, b, eos, args, device):
    """Return readout logits for the token following the sequence ``ids``.

    The whole sequence is re-processed by the stack on every call (there is
    no incremental cache); only the last position's features are projected
    through the readout ``W``/``b``.
    """
    x = torch.tensor(ids, device=device, dtype=torch.long)
    _, within = doc_index(x, eos)
    H, phis = apply_stack(x, E, Ps, Bs, within, args)
    Phi = features(H, within, phis, args.extra_context)[-1:]  # keep last row only
    return (Phi @ W + b).squeeze(0)


def sample(logits, temp=0.8, top_k=40):
    """Draw one token id from a 1-D ``logits`` tensor.

    ``temp <= 0`` selects the argmax (greedy decoding).  ``top_k <= 0``
    disables top-k filtering and samples from the full distribution; larger
    values are clamped to the vocabulary size.  Uses torch's global RNG.
    """
    if temp <= 0:
        return int(torch.argmax(logits))
    logits = logits / max(temp, 1e-6)
    n = logits.numel()
    k = n if top_k <= 0 else min(top_k, n)
    vals, idx = torch.topk(logits, k)
    probs = F.softmax(vals, dim=-1)
    return int(idx[torch.multinomial(probs, 1)])


def _generate(ids, E, Ps, Bs, W, b, eos, args, device):
    """Append up to ``args.max_new`` sampled tokens to ``ids`` in place and return it.

    Stops right after EOS is sampled (the EOS id is kept in ``ids``).
    Runs without autograd since only sampled ids are needed.
    """
    with torch.no_grad():
        for _ in range(args.max_new):
            logits = next_logits(ids, E, Ps, Bs, W, b, eos, args, device)
            tok = sample(logits, args.sample_temp, args.top_k)
            ids.append(tok)
            if tok == eos:
                break
    return ids


def _build_parser():
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/workspace/oneshot/logs/glm_d896_readout/stream_ce_lam10.pt",
                    help="checkpoint holding the readout tensors 'W' and 'b'")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin",
                    help="uint16 token stream used to rebuild the embedding and maps")
    ap.add_argument("--d", type=int, default=896)
    # The flags below are not read directly in this file; they are passed to
    # the torch_predictive_attn helpers through `args` (presumably consumed
    # there -- confirm before removing any).
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    # Generation controls.
    ap.add_argument("--max_new", type=int, default=80,
                    help="maximum number of tokens sampled per prompt")
    ap.add_argument("--sample_temp", type=float, default=0.8,
                    help="sampling temperature (<= 0 means greedy)")
    ap.add_argument("--top_k", type=int, default=40,
                    help="top-k filter size (<= 0 disables filtering)")
    ap.add_argument("--seed", type=int, default=None,
                    help="optional RNG seed for reproducible runs")
    ap.add_argument("--prompt", action="append", default=None,
                    help="prompt to complete (repeatable); defaults to built-in prompts")
    return ap


def main():
    """Parse flags, rebuild the model and print one completion per prompt."""
    args = _build_parser().parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    # Load the readout before build() so a bad --ckpt fails fast instead of
    # after the expensive embedding/layer construction.
    ck = torch.load(args.ckpt, map_location="cpu")
    sp, eos, E, Ps, Bs = build(args, device)
    W = ck["W"].to(device)
    b = ck["b"].to(device)
    prompts = args.prompt if args.prompt else DEFAULT_PROMPTS
    for p in prompts:
        ids = _generate(sp.encode(p), E, Ps, Bs, W, b, eos, args, device)
        print("=== PROMPT ===")
        print(p)
        print("=== OUTPUT ===")
        print(sp.decode(ids))


if __name__ == "__main__":
    main()