File size: 3,747 Bytes
a216fa7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import argparse, os
import numpy as np
import torch
import torch.nn.functional as F
from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log
def build(args, device):
    """Load the tokenizer and fit the embedding plus the per-layer maps.

    The training tokens are read from ``args.train_bin`` as a flat array of
    ``uint16`` values. An embedding ``E`` is produced by ``ppmi_embed`` and
    then ``learn_map`` is invoked ``args.layers`` times; each call receives
    the lists of maps learned so far and its two results are appended to
    those lists.

    Args:
        args: Parsed command-line namespace. This function reads
            ``spm_model``, ``train_bin``, ``d``, ``window``, ``cooc_tokens``
            and ``layers`` directly; the whole namespace is also handed to
            ``learn_map``.
        device: Device specifier forwarded to ``ppmi_embed`` and
            ``learn_map``.

    Returns:
        Tuple ``(sp, eos, E, Ps, Bs)``: the SentencePiece processor, its
        end-of-sequence id, the embedding, and the two lists filled by the
        ``learn_map`` calls.
    """
    # Imported lazily so the module can be imported without sentencepiece.
    import sentencepiece as spm

    tokenizer = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = tokenizer.eos_id()
    vocab_size = tokenizer.get_piece_size()

    corpus = np.fromfile(args.train_bin, dtype=np.uint16)
    E = ppmi_embed(
        corpus, vocab_size, args.d, args.window, args.cooc_tokens, device
    )

    Ps = []
    Bs = []
    for _layer in range(args.layers):
        # learn_map is given the *same* list objects that are extended below,
        # so every call sees all previously learned layers.
        new_P, new_B = learn_map(corpus, E, Ps, Bs, eos, args, device)
        Ps.append(new_P)
        Bs.append(new_B)

    return tokenizer, eos, E, Ps, Bs
def next_logits(ids, E, Ps, Bs, W, b, eos, args, device):
    """Compute the readout scores for the position after the last token.

    The whole id sequence is pushed through ``doc_index``, ``apply_stack``
    and ``features``; only the final row of the resulting feature matrix is
    kept and passed through the affine readout ``W``/``b``.

    Args:
        ids: Sequence of token ids (converted to a ``torch.long`` tensor).
        E, Ps, Bs: Embedding and per-layer maps, forwarded to
            ``apply_stack``.
        W, b: Readout weight and bias applied as ``Phi @ W + b``.
        eos: End-of-sequence id, forwarded to ``doc_index``.
        args: Parsed namespace; ``extra_context`` is read here and the whole
            namespace is forwarded to ``apply_stack``.
        device: Device on which the id tensor is created.

    Returns:
        The readout result for the last position with its leading dimension
        squeezed away.
    """
    tokens = torch.tensor(ids, device=device, dtype=torch.long)

    _unused, within = doc_index(tokens, eos)
    hidden, phis = apply_stack(tokens, E, Ps, Bs, within, args)

    all_feats = features(hidden, within, phis, args.extra_context)
    # Slice (not index) so the single kept row stays 2-D for the matmul.
    last_feat = all_feats[-1:]

    scores = last_feat @ W + b
    return scores.squeeze(0)
def sample(logits: torch.Tensor, temp: float = 0.8, top_k: int = 40) -> int:
    """Draw one index from temperature-scaled, top-k-truncated logits.

    Args:
        logits: Tensor of unnormalised scores; the returned value indexes
            into its last dimension. Callers in this file pass a 1-D tensor.
        temp: Sampling temperature. Values below ``1e-6`` (including ``0``)
            are clamped to ``1e-6``, which makes sampling effectively greedy
            instead of dividing by zero.
        top_k: Number of highest-scoring candidates to sample from. It is
            capped at the number of available logits. A value ``<= 0``
            disables truncation and samples from the full distribution
            (previously such values made ``topk``/``multinomial`` fail).

    Returns:
        The sampled index as a plain Python ``int``.
    """
    scaled = logits / max(temp, 1e-6)

    n_candidates = scaled.numel()
    # Non-positive top_k means "keep everything" rather than an empty
    # candidate set that multinomial cannot sample from.
    k = n_candidates if top_k <= 0 else min(top_k, n_candidates)

    top_vals, top_idx = torch.topk(scaled, k)
    probs = F.softmax(top_vals, dim=-1)

    # multinomial returns a position within the top-k slice; map it back to
    # the original vocabulary index.
    choice = torch.multinomial(probs, 1)
    return int(top_idx[choice])
def main():
    """Build the model, load the readout checkpoint and sample completions.

    Command-line options are parsed into a namespace that is passed on to
    ``build`` and ``next_logits``. This function itself reads ``ckpt``,
    ``max_new``, ``sample_temp`` and ``top_k``. For each built-in prompt it
    appends up to ``max_new`` sampled tokens (stopping early once the
    end-of-sequence id is drawn) and prints the prompt followed by the
    decoded sequence.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/workspace/oneshot/logs/glm_d896_readout/stream_ce_lam10.pt")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
    ap.add_argument("--d", type=int, default=896)
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
    ap.add_argument("--max_new", type=int, default=80)
    ap.add_argument("--sample_temp", type=float, default=0.8)
    ap.add_argument("--top_k", type=int, default=40)
    args = ap.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sp, eos, E, Ps, Bs = build(args, device)

    # NOTE(security): torch.load unpickles the file, so a checkpoint from an
    # untrusted source can execute arbitrary code. If the checkpoint holds
    # only tensors, consider passing weights_only=True -- TODO confirm the
    # checkpoint contents before tightening this.
    ck = torch.load(args.ckpt, map_location=device)
    W = ck["W"].to(device)
    b = ck["b"].to(device)

    prompts = [
        "Question: What is gravity?\nAnswer:",
        "Question: Why is the sky blue?\nAnswer:",
        "Question: Explain photosynthesis in simple words.\nAnswer:",
        "Question: If John has 3 apples and buys 2 more, how many apples does he have?\nAnswer:",
        "Question: Write a short friendly story about a robot learning to read.\nAnswer:",
    ]

    # Generation only needs forward values; disabling autograd avoids
    # building a graph on every step when loaded tensors require grad.
    with torch.no_grad():
        for p in prompts:
            ids = sp.encode(p)
            for _ in range(args.max_new):
                # The full sequence so far is re-scored for every new token.
                logits = next_logits(ids, E, Ps, Bs, W, b, eos, args, device)
                tok = sample(logits, args.sample_temp, args.top_k)
                ids.append(tok)
                if tok == eos:
                    break
            print("=== PROMPT ===")
            print(p)
            print("=== OUTPUT ===")
            # ids still starts with the prompt tokens, so the decoded text
            # contains the prompt followed by the continuation.
            print(sp.decode(ids))
# Run the sampling demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|