| import argparse, os |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log |
|
|
|
|
def build(args, device):
    """Rebuild the tokenizer, embedding table and per-layer maps from training data.

    Steps:
      1. Load the SentencePiece model from ``args.spm_model``.
      2. Read the training token stream from ``args.train_bin`` as uint16.
      3. Compute the embedding with ``ppmi_embed``.
      4. Call ``learn_map`` once per layer (``args.layers`` times). Each call
         receives the maps learned so far, so later layers can build on
         earlier ones.

    Args:
        args: parsed CLI namespace. This function reads ``spm_model``,
            ``train_bin``, ``d``, ``window``, ``cooc_tokens`` and ``layers``.
            ``learn_map`` receives the whole namespace.
        device: torch device string forwarded to the helpers.

    Returns:
        Tuple ``(sp, eos, E, Ps, Bs)``:
          - ``sp``: the SentencePiece processor.
          - ``eos``: its EOS id.
          - ``E``: the embedding returned by ``ppmi_embed``.
          - ``Ps``, ``Bs``: lists with one entry per layer, holding the two
            values returned by each ``learn_map`` call (presumably projection
            and bias; confirm in torch_predictive_attn).
    """
    import sentencepiece as spm

    tokenizer = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos_id = tokenizer.eos_id()
    vocab_size = tokenizer.get_piece_size()
    corpus = np.fromfile(args.train_bin, dtype=np.uint16)
    embedding = ppmi_embed(
        corpus, vocab_size, args.d, args.window, args.cooc_tokens, device
    )

    layer_maps, layer_biases = [], []
    layer = 0
    while layer < args.layers:
        # The same list objects are passed every iteration, so each call
        # sees all maps appended by previous layers.
        new_map, new_bias = learn_map(
            corpus, embedding, layer_maps, layer_biases, eos_id, args, device
        )
        layer_maps.append(new_map)
        layer_biases.append(new_bias)
        layer += 1

    return tokenizer, eos_id, embedding, layer_maps, layer_biases
|
|
|
|
def next_logits(ids, E, Ps, Bs, W, b, eos, args, device):
    """Return readout logits for the token that follows ``ids``.

    The full stack is recomputed over the whole sequence on every call (no
    caching). Only the feature row of the final position is fed through the
    linear readout ``W`` / ``b``.

    Args:
        ids: sequence of token ids (the prompt plus tokens generated so far).
        E, Ps, Bs: embedding and per-layer maps as returned by ``build``.
        W, b: readout weight and bias loaded from the checkpoint.
        eos: EOS token id, used by ``doc_index`` to find document boundaries.
        args: CLI namespace forwarded to ``apply_stack``; also supplies
            ``extra_context``.
        device: torch device for the token tensor.

    Returns:
        The last-position logits with the leading singleton row removed.
        Presumably shape ``(vocab,)``, assuming ``features`` returns a 2-D
        ``(seq, feat)`` tensor; TODO confirm.
    """
    tokens = torch.tensor(ids, device=device, dtype=torch.long)
    within = doc_index(tokens, eos)[1]
    hidden, layer_phis = apply_stack(tokens, E, Ps, Bs, within, args)
    all_feats = features(hidden, within, layer_phis, args.extra_context)
    # Keep a 2-D (1, F) slice so the matmul below stays a matrix product.
    last_row = all_feats[-1:]
    scores = last_row @ W + b
    return scores.squeeze(0)
|
|
|
|
def sample(logits, temp=0.8, top_k=40):
    """Draw one token id from a 1-D logits tensor.

    Args:
        logits: 1-D tensor of unnormalised scores, one per candidate token.
        temp: softmax temperature.
            - Values ``<= 0`` select the arg-max deterministically (greedy
              decoding).
            - Positive values below 1e-6 are clamped to 1e-6.
        top_k: restrict sampling to the ``top_k`` highest-scoring entries.
            ``None`` or a value ``<= 0`` disables truncation and samples from
            the full distribution. A positive value larger than the number of
            entries is clamped.

    Returns:
        The chosen token id as a Python ``int``.
    """
    # Work in float32 so half-precision logits cannot overflow to inf (and
    # then NaN in softmax) when divided by a small temperature.
    logits = logits.float()

    if temp <= 0:
        # Explicit greedy decoding instead of dividing by a tiny epsilon.
        return int(torch.argmax(logits))

    logits = logits / max(temp, 1e-6)

    if top_k is None or top_k <= 0:
        # No truncation. The previous code passed k=0 to topk here, which
        # produced an empty distribution and crashed multinomial.
        probs = F.softmax(logits, dim=-1)
        return int(torch.multinomial(probs, 1))

    vals, idx = torch.topk(logits, min(top_k, logits.numel()))
    probs = F.softmax(vals, dim=-1)
    # multinomial picks a position within the top-k slice; map it back to the
    # original vocabulary index.
    return int(idx[torch.multinomial(probs, 1)])
|
|
|
|
_DEFAULT_PROMPTS = (
    "Question: What is gravity?\nAnswer:",
    "Question: Why is the sky blue?\nAnswer:",
    "Question: Explain photosynthesis in simple words.\nAnswer:",
    "Question: If John has 3 apples and buys 2 more, how many apples does he have?\nAnswer:",
    "Question: Write a short friendly story about a robot learning to read.\nAnswer:",
)


def _parse_args():
    """Build the CLI parser and return the parsed namespace."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/workspace/oneshot/logs/glm_d896_readout/stream_ce_lam10.pt")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--max_new", type=int, default=80,
                    help="maximum number of tokens to generate per prompt")
    ap.add_argument("--sample_temp", type=float, default=0.8,
                    help="sampling temperature (<= 0 means greedy)")
    ap.add_argument("--top_k", type=int, default=40,
                    help="top-k truncation for sampling")
    ap.add_argument("--seed", type=int, default=None,
                    help="seed torch/numpy RNGs for reproducible sampling")
    ap.add_argument("--prompt", action="append", default=None,
                    help="prompt to complete (repeatable); defaults to a built-in set")
    return ap.parse_args()


def _generate(prompt_ids, E, Ps, Bs, W, b, eos, args, device):
    """Extend ``prompt_ids`` with up to ``args.max_new`` sampled tokens.

    Stops right after an EOS token is sampled. The EOS is kept in the
    returned list.
    """
    ids = list(prompt_ids)
    for _ in range(args.max_new):
        logits = next_logits(ids, E, Ps, Bs, W, b, eos, args, device)
        tok = sample(logits, args.sample_temp, args.top_k)
        ids.append(tok)
        if tok == eos:
            break
    return ids


def main():
    """Rebuild the feature stack, load the readout checkpoint and sample completions.

    For each prompt, prints the prompt followed by the decoded prompt+completion.
    """
    args = _parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.seed is not None:
        # Seed before build() as well, in case any helper there draws random
        # numbers; torch.manual_seed also seeds CUDA generators.
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    sp, eos, E, Ps, Bs = build(args, device)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ck = torch.load(args.ckpt, map_location=device)
    W = ck["W"].to(device)
    b = ck["b"].to(device)

    prompts = args.prompt if args.prompt else list(_DEFAULT_PROMPTS)
    # Pure inference: skip autograd bookkeeping during the (quadratic)
    # full-sequence recomputation in next_logits.
    with torch.no_grad():
        for p in prompts:
            ids = _generate(sp.encode(p), E, Ps, Bs, W, b, eos, args, device)
            print("=== PROMPT ===")
            print(p)
            print("=== OUTPUT ===")
            print(sp.decode(ids))
|
|
|
|
if __name__ == "__main__":
    # Run generation only when executed as a script, not on import.
    main()
|
|