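"""Sample completions from a predictive-attention language model.

The PPMI embedding and the per-layer maps are rebuilt from the training
corpus at startup (see build()); only the linear readout (W, b) is read
from the checkpoint given by --ckpt.
"""
import argparse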
import numpy as np
import torch
import torch.nn.functional as F
from torch_predictive_attn import ppmi_embed, learn_map, doc_index, apply_stack, features, log


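def build(args, device):
    """Rebuild the tokenizer, embedding E and per-layer maps (Ps, Bs) from the training data."""
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file=args.spm_model)
    eos = sp.eos_id()
    V = sp.get_piece_size()
    # Token ids are stored as uint16, so every id must be below 65,536.
    train = np.fromfile(args.train_bin, dtype=np.uint16)
    E = ppmi_embed(train, V, args.d, args.window, args.cooc_tokens, device)
    Ps, Bs = [], []
    # Each layer's map is fitted given the layers already fitted (Ps, Bs are passed in).
    for _ in range(args.layers):
        P, B = learn_map(train, E, Ps, Bs, eos, args, device)
        Ps.append(P)
        Bs.append(B)
    return sp, eos, E, Ps, Bs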


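def next_logits(ids, E, Ps, Bs, W, b, eos, args, device):
    """Return readout logits for the token that follows `ids`.

    The whole stack is re-run over the full prefix on every call (no caching),
    so each newly generated token re-encodes the entire sequence.
    """
    x = torch.tensor(ids, device=device, dtype=torch.long)
    _, within = doc_index(x, eos)
    H, phis = apply_stack(x, E, Ps, Bs, within, args)
    # Keep only the last position's features, with a leading dimension of 1.
    Phi = features(H, within, phis, args.extra_context)[-1:]
    return (Phi @ W + b).squeeze(0)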


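def sample(logits, temp=0.8, top_k=40):
    """Draw one token id using temperature scaling and top-k truncation."""
    logits = logits / max(temp, 1e-6)  # guard against division by zero
    vals, idx = torch.topk(logits, min(top_k, logits.numel()))
    probs = F.softmax(vals, dim=-1)  # renormalize over the top-k candidates only
    return int(idx[torch.multinomial(probs, 1)])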
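

# Illustrative helper (hypothetical; not part of the original script and not
# called by main()). A minimal sketch, assuming 1-D float logits, that returns
# the full-vocabulary distribution sample() draws from, so the effect of
# --sample_temp and --top_k can be inspected. Tokens outside the top-k get
# probability exactly zero; lower temperatures sharpen the remaining mass.
def topk_temperature_probs(logits, temp=0.8, top_k=40):
    import torch
    import torch.nn.functional as F

    scaled = logits / max(temp, 1e-6)
    k = min(top_k, scaled.numel())
    vals, idx = torch.topk(scaled, k)
    probs = torch.zeros_like(scaled)
    # Scatter the renormalized top-k probabilities back to their vocabulary ids.
    probs[idx] = F.softmax(vals, dim=-1)
    return probs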


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/workspace/oneshot/logs/glm_d896_readout/stream_ce_lam10.pt")
    ap.add_argument("--spm_model", default="/workspace/glm/glm16k.model")
    ap.add_argument("--train_bin", default="/workspace/glm/glm_train.bin")
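    # Model hyperparameters consumed by the torch_predictive_attn helpers; the
    # feature layout they produce must match the readout (W, b) in --ckpt.
    ap.add_argument("--d", type=int, default=896)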
    ap.add_argument("--r", type=int, default=320)
    ap.add_argument("--layers", type=int, default=10)
    ap.add_argument("--att_window", type=int, default=10)
    ap.add_argument("--temp", type=float, default=0.28)
    ap.add_argument("--window", type=int, default=10)
    ap.add_argument("--extra_context", type=int, default=1)
    ap.add_argument("--res_scale", type=float, default=0.07)
    ap.add_argument("--pred_scale", type=float, default=0.035)
    ap.add_argument("--pred_schedule", default="late")
    ap.add_argument("--orth_delta", type=int, default=1)
    ap.add_argument("--pred_norm", type=int, default=1)
    ap.add_argument("--pred_features", type=int, default=1)
    ap.add_argument("--map_lam", type=float, default=0.001)
    ap.add_argument("--cooc_tokens", type=int, default=3_600_000)
    ap.add_argument("--proj_tokens", type=int, default=3_600_000)
    ap.add_argument("--chunk_docs", type=int, default=8)
    ap.add_argument("--value_mode", default="dual_ridge_delta")
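    ap.add_argument("--max_new", type=int, default=80,
                    help="maximum number of tokens to generate per prompt")
    ap.add_argument("--sample_temp", type=float, default=0.8,
                    help="temperature used by sample(); distinct from --temp")
    ap.add_argument("--top_k", type=int, default=40,
                    help="number of highest-scoring tokens kept when sampling")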
    args = ap.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sp, eos, E, Ps, Bs = build(args, device)
    # Only the linear readout (W, b) is read from the checkpoint.
    ck = torch.load(args.ckpt, map_location=device)
    W = ck["W"].to(device)
    b = ck["b"].to(device)
    prompts = [
        "Question: What is gravity?\nAnswer:",
        "Question: Why is the sky blue?\nAnswer:",
        "Question: Explain photosynthesis in simple words.\nAnswer:",
        "Question: If John has 3 apples and buys 2 more, how many apples does he have?\nAnswer:",
        "Question: Write a short friendly story about a robot learning to read.\nAnswer:",
    ]
    for p in prompts:
        ids = sp.encode(p)
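        # Autoregressive decoding: append one sampled token at a time, stopping at EOS.
        for _ in range(args.max_new):
            logits = next_logits(ids, E, Ps, Bs, W, b, eos, args, device)
            tok = sample(logits, args.sample_temp, args.top_k)
            ids.append(tok)
            if tok == eos:
                break
        print("=== PROMPT ===")
        print(p)
        print("=== OUTPUT ===")
        # The decoded text contains the prompt followed by the generated continuation.
        print(sp.decode(ids))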


if __name__ == "__main__":
    main()