MarxistLeninist commited on
Commit
6393d0d
·
verified ·
1 Parent(s): 27db0fa

Upload a3.py

Browse files

Inference code (NAT inference via the chat path is still broken and needs fixing; AR inference works.)

Files changed (1) hide show
  1. a3.py +318 -0
a3.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""
a3.py – joint-train low-rank AR + NAT, auto-resume at epoch 84

• Loads ar_ep084.pt & nat_ep084.pt from ckpts1/ if present, then trains
  from epoch 85. Otherwise starts from scratch.
• Dataset: WikiText-103 (raw) streamed, default cap = 100 M tokens.
• Checkpoints: epoch 1, every 5 epochs, and final.
• Default preset = small (fits 11 GB GPUs).
"""

from __future__ import annotations
import argparse, math, pathlib, time
from contextlib import nullcontext

import torch, torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, logging as hf_log

# ╭─ AMP shim ─╮
# torch >= 2.2 moved autocast/GradScaler to torch.amp; fall back to the
# older torch.cuda.amp location on earlier installs.
try:
    from torch.amp import autocast as _ac_new
    from torch.amp import GradScaler
    _AMP = "new"
except ImportError:  # torch < 2.2
    from torch.cuda.amp import autocast as _ac_old
    from torch.cuda.amp import GradScaler
    _AMP = "old"

def amp(enabled, dtype, device="cuda"):
    """Return an autocast context for *device*/*dtype*, or a no-op.

    Dispatches to whichever autocast API the installed torch exposes;
    with enabled=False the caller runs in full precision via nullcontext.
    """
    if not enabled:
        return nullcontext()
    return _ac_new(device_type=device, dtype=dtype) if _AMP == "new" else _ac_old(dtype=dtype)
# ╰─────────────╯
hf_log.set_verbosity_error()                 # silence transformers warnings
torch.backends.cuda.matmul.allow_tf32 = True # free speed-up on Ampere+

# ───────────── presets ─────────────
# Model-size presets: dimensions / depth / head counts for the AR decoder
# and the (larger) NAT model.
PRESETS = {
    "small": dict(ar_d=512, ar_layers=8, ar_heads=16,
                  nat_d=640, nat_layers=12, nat_heads=20),
    "base": dict(ar_d=768, ar_layers=12, ar_heads=24,
                 nat_d=1024,nat_layers=16, nat_heads=32),
    "large": dict(ar_d=1024, ar_layers=16, ar_heads=32,
                  nat_d=1280,nat_layers=24, nat_heads=40),
}
BLOCK = 128                   # tokens per training sample
DROP_P = 0.1                  # dropout used throughout both models
LR_AR = LR_NAT = 2e-4         # per-model learning rates (same value)
ALPHA_KL = 1.0                # weight of the AR→NAT distillation KL term
CKDIR = pathlib.Path("ckpts1")  # checkpoint directory
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_EVERY = 5                # checkpoint every N epochs (plus ep 1 & last)
RESUME_EPOCH= 84 # ← hard-coded resume point
# ───────────── tokenizer ─────────────
tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-0528", use_fast=True)
if tok.pad_token is None:
    # Guarantee a pad token; it doubles as the CTC blank / NAT placeholder.
    tok.add_special_tokens({"pad_token": "[PAD]"})
BLANK_ID = tok.pad_token_id                 # blank / placeholder token id
VOCAB = 1 + max(tok.get_vocab().values())   # logit width covering every id
# ───────────── data streaming ─────────────
def stream_wikitext(max_tokens=0):
    """Stream token ids from the WikiText-103-raw train split.

    Stops once *max_tokens* ids have been produced; 0 means no cap.
    """
    emitted = 0
    ds = load_dataset("wikitext", "wikitext-103-raw-v1",
                      split="train", streaming=True)
    for example in ds:
        for tok_id in tok.encode(example["text"]):
            yield tok_id
            emitted += 1
            if max_tokens and emitted >= max_tokens:
                return
class ARDataset(IterableDataset):
    """Sliding-window LM pairs: (block, block shifted left by one).

    NOTE(review): the window advances one token per sample, so consecutive
    samples overlap almost entirely — an "epoch" is ~max_tokens samples,
    not max_tokens/BLOCK. Confirm this stride is intended.
    """

    def __init__(self, blk, max_tokens=0):
        self.blk = blk
        self.max = max_tokens

    def __iter__(self):
        window = []
        for token in stream_wikitext(self.max):
            window.append(token)
            # Need blk+1 tokens before a (input, target) pair exists.
            while len(window) > self.blk:
                x = torch.tensor(window[:self.blk])
                y = torch.tensor(window[1:self.blk + 1])
                yield x, y
                del window[0]
class NATDataset(IterableDataset):
    """Non-overlapping NAT samples: blank-interleaved input vs. raw target.

    The input has length 2·blk with BLANK_ID at even positions and the
    target token at odd positions: [blank, t0, blank, t1, ...].
    """

    def __init__(self, blk, max_tokens=0):
        self.blk = blk
        self.max = max_tokens

    def __iter__(self):
        pending = []
        for token in stream_wikitext(self.max):
            pending.append(token)
            while len(pending) >= self.blk:
                target = pending[:self.blk]
                pending = pending[self.blk:]
                interleaved = []
                for t in target:
                    interleaved.append(BLANK_ID)  # even slot: blank
                    interleaved.append(t)         # odd slot: token
                yield torch.tensor(interleaved), torch.tensor(target)
# ───────────── transformer components ─────────────
class LowRankMHA(nn.Module):
    """Multi-head attention with per-head states projected to rank *r*.

    All heads share one basis ``U`` (d_k × r), so the attention matmuls run
    in an r-dimensional subspace instead of the full head dimension.
    """

    def __init__(self, d, h, r):
        super().__init__()
        self.h, self.dk = h, d // h
        # FIX: the original wrote `self.q = self.k = self.v = nn.Linear(...)`,
        # aliasing ONE Linear as all three projections, which forces
        # query == key == value projections. Use three independent layers.
        # Old checkpoints still load: they contain q./k./v. keys (with tied
        # values), which simply initialise the three layers identically.
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(DROP_P)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then project onto U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x, mask=None):
        q, k, v = map(self._proj, (self.q(x), self.k(x), self.v(x)))
        # NOTE(review): scores are scaled by sqrt(d_k) although q/k now live
        # in rank-r space; sqrt(r) would be the textbook scale. Left as-is
        # so behaviour of existing trained checkpoints is unchanged.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask  # additive mask, e.g. causal -inf upper triangle
        out = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        return self.drop(self.proj(out))
class Block(nn.Module):
    """Pre-norm transformer block: LN→attention and LN→MLP, each residual."""

    def __init__(self, d, h, dff, r):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r)
        self.ff = nn.Sequential(
            nn.Linear(d, dff),
            nn.ReLU(),
            nn.Dropout(DROP_P),
            nn.Linear(dff, d),
        )

    def forward(self, x, mask=None):
        x = x + self.mha(self.ln1(x), mask)   # attention sub-layer
        x = x + self.ff(self.ln2(x))          # feed-forward sub-layer
        return x
# ───────────── model builders ─────────────
def make_transformer(d, n_layers, n_heads, vocab, max_len=8192):
    """Assemble a bare nn.Module container holding the network pieces.

    The container defines no forward(); run it through fwd() below.
    FFN width is 4·d and the shared low-rank basis is max(32, d//16).
    """
    dff = 4 * d
    low_rank = max(32, d // 16)
    model = nn.Module()
    model.emb = nn.Embedding(vocab, d)      # token embeddings
    model.pos = nn.Embedding(max_len, d)    # learned absolute positions
    model.blocks = nn.ModuleList(
        [Block(d, n_heads, dff, low_rank) for _ in range(n_layers)]
    )
    model.ln = nn.LayerNorm(d)              # final norm
    model.out = nn.Linear(d, vocab)         # logit head
    return model
def make_ar(cfg):
    """Build the AR decoder for a preset dict (positions capped at 4096)."""
    return make_transformer(cfg["ar_d"], cfg["ar_layers"],
                            cfg["ar_heads"], VOCAB, 4096)

def make_nat(cfg):
    """Build the NAT model for a preset dict (positions capped at 8192)."""
    return make_transformer(cfg["nat_d"], cfg["nat_layers"],
                            cfg["nat_heads"], VOCAB, 8192)
# ───────────── NAT helpers ─────────────
class NATWrap(nn.Module):
    """Double every input position, then call the wrapped model.

    repeat_interleave(x, 2, 1) maps [t0, t1, ...] to [t0, t0, t1, t1, ...],
    matching the 2× sequence length the NAT operates on.
    NOTE(review): training interleaves BLANK_ID with tokens, whereas this
    duplicates the tokens themselves — confirm the layout is intended.
    """

    def __init__(self, core):
        super().__init__()
        self.core = core  # callable mapping (B, 2N) ids -> logits

    def forward(self, x):
        doubled = torch.repeat_interleave(x, 2, 1)
        return self.core(doubled)
class ParScale(nn.Module):
    # Multi-stream NAT decoder: per refinement pass, take the top-P token
    # candidates at every position and keep the stream judged best.
    def __init__(self, nat, P): super().__init__(); self.nat,self.P = nat,P
    @torch.no_grad()
    def generate(self, x, passes=1):
        """Iteratively refine id tensor *x* (B, N) for *passes* rounds.

        `self.nat` is expected to map (B, N) ids to (B, 2N, vocab) logits
        (see NATWrap); each pass halves the 2N axis back to N via [::2].
        """
        for _ in range(passes):
            # Suppress the blank id so top-k never proposes it outright.
            logits = self.nat(x); logits[..., BLANK_ID] = -1e9
            # cand: (P, B, 2N) — the k-th best token id at every position.
            cand = logits.topk(self.P, -1).indices.permute(2,0,1)
            # Score each stream by its fraction of non-blank tokens.
            # NOTE(review): since blanks were just masked to -1e9, cand
            # essentially never contains BLANK_ID, so all streams score
            # ~1.0 and argmax collapses to stream 0 — likely part of the
            # known NAT-inference breakage; confirm intended behaviour.
            best = (cand != BLANK_ID).float().mean(-1).argmax(0)
            # Select the chosen stream per batch row, then drop the odd
            # (duplicated) positions to return to length N.
            x = cand[best, torch.arange(x.size(0), device=x.device)][:, ::2]
        return x
# ───────────── helpers ─────────────
def fwd(model, ids, causal=False):
    """Run a make_transformer() container over *ids* → (B, N, vocab) logits.

    causal=True adds an upper-triangular -inf mask (AR decoding); otherwise
    attention is unrestricted (NAT).
    """
    seq_len = ids.size(1)
    positions = torch.arange(seq_len, device=ids.device)
    hidden = model.emb(ids) + model.pos(positions)

    if causal:
        full = torch.full((1, 1, seq_len, seq_len), float("-inf"),
                          device=ids.device)
        mask = torch.triu(full, 1)
    else:
        mask = None

    for layer in model.blocks:
        hidden = layer(hidden, mask)
    return model.out(model.ln(hidden))
# ───────────── training ─────────────
def train_joint(a):
    """Jointly train the AR and NAT models from CLI args *a*.

    Per batch: AR cross-entropy, NAT CTC on blank-interleaved input, and a
    KL term distilling AR logits into the NAT on the AR's input. Resumes
    from RESUME_EPOCH checkpoints in CKDIR when both files exist.
    """
    cfg = PRESETS[a.preset]
    # NOTE(review): two loaders means the HF stream is read twice in
    # parallel (one pass per dataset) every epoch.
    ar_loader = DataLoader(ARDataset(BLOCK, a.max_tokens), batch_size=a.batch)
    nat_loader = DataLoader(NATDataset(BLOCK, a.max_tokens), batch_size=a.batch)
    ar , nat = make_ar(cfg).to(DEV), make_nat(cfg).to(DEV)

    # ----- resume if we have epoch-84 weights -----
    start_ep = 0
    ck_ar = CKDIR / f"ar_ep{RESUME_EPOCH:03d}.pt"
    ck_nat = CKDIR / f"nat_ep{RESUME_EPOCH:03d}.pt"
    if ck_ar.exists() and ck_nat.exists():
        ar.load_state_dict(torch.load(ck_ar, map_location=DEV))
        nat.load_state_dict(torch.load(ck_nat, map_location=DEV))
        start_ep = RESUME_EPOCH
        print(f"Resuming from epoch {start_ep} checkpoints.")

    # One optimizer, two param groups so AR and NAT can use their own LR.
    opt = torch.optim.AdamW(
        [{"params": ar.parameters(), "lr": LR_AR},
         {"params": nat.parameters(), "lr": LR_NAT}]
    )

    # >>>>>>> FIX: ensure 'initial_lr' so scheduler can resume <<<<<<<
    # CosineAnnealingLR with last_epoch >= 0 requires 'initial_lr' in each
    # param group; a freshly built optimizer does not have it yet.
    for pg in opt.param_groups:
        pg.setdefault("initial_lr", pg["lr"])
    # ------------------------------------------------------------------

    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=a.epochs, last_epoch=start_ep - 1
    )

    ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True)
    kl = nn.KLDivLoss(reduction="batchmean")

    use_amp = DEV.type == "cuda" and a.amp
    # NOTE(review): GradScaler is generally unnecessary with bfloat16
    # autocast (it targets fp16 underflow); harmless but redundant here.
    scaler = GradScaler(enabled=use_amp)
    cast_dt = torch.bfloat16 if use_amp else torch.float32
    CKDIR.mkdir(exist_ok=True)
    # Progress-bar total only; assumes stride-BLOCK samples, but ARDataset
    # advances by 1 token per sample — NOTE(review): the estimate is off.
    tot_batches = None if not a.max_tokens else math.ceil(
        math.ceil(a.max_tokens / BLOCK) / a.batch)

    for ep in range(start_ep + 1, a.epochs + 1):
        ar.train(); nat.train(); tot = steps = 0
        loop = tqdm(zip(ar_loader, nat_loader), total=tot_batches,
                    desc=f"Epoch {ep}/{a.epochs}", unit="batch")
        for (x_ar, y_ar), (x_nat, y_nat) in loop:
            x_ar, y_ar, x_nat, y_nat = map(lambda t: t.to(DEV),
                                           (x_ar, y_ar, x_nat, y_nat))
            opt.zero_grad(set_to_none=True)
            with amp(use_amp, cast_dt, DEV.type):
                # AR: causal LM cross-entropy.
                logits_ar = fwd(ar, x_ar, causal=True)
                loss_ar = ce(logits_ar.reshape(-1, VOCAB), y_ar.reshape(-1))

                # NAT: CTC over (T, B, V) log-probs; input length is the
                # target length (x_nat is 2x the target, so size(1)//2).
                logp_nat = fwd(nat, x_nat).log_softmax(-1).transpose(0, 1)
                ilen=tlen = torch.full((x_nat.size(0),), x_nat.size(1)//2,
                                       dtype=torch.long, device=DEV)
                loss_nat = ctc(logp_nat, y_nat, ilen, tlen)

                # Distillation: NAT run on the AR input, pulled toward the
                # (detached) AR distribution.
                loss_kld = kl(fwd(nat, x_ar).log_softmax(-1),
                              logits_ar.softmax(-1).detach())

                loss = loss_ar + loss_nat + ALPHA_KL * loss_kld

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real grads.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(ar.parameters(), 1.0)
            nn.utils.clip_grad_norm_(nat.parameters(), 1.0)
            scaler.step(opt); scaler.update()

            tot += loss.item(); steps += 1
            loop.set_postfix(loss=f"{loss.item():.3f}",
                             avg=f"{tot/steps:.3f}", refresh=False)
        sched.step()

        # Checkpoint on epoch 1, every SAVE_EVERY epochs, and at the end.
        if ep == 1 or ep % SAVE_EVERY == 0 or ep == a.epochs:
            torch.save(nat.state_dict(), CKDIR / f"nat_ep{ep:03d}.pt")
            torch.save(ar.state_dict(), CKDIR / f"ar_ep{ep:03d}.pt")
            print(f"Epoch {ep}: checkpoints written.")
        print(f"Epoch {ep}: avg loss {tot/max(steps,1):.4f}")
# ───────────── inference helpers ─────────────
@torch.no_grad()
def nat_infer(ckpt, prompt, max_new, passes, streams, preset):
    """Decode *prompt* with the NAT checkpoint *ckpt* and print the result.

    Pads the prompt with max_new blanks and refines for *passes* rounds
    using ParScale with *streams* candidate streams.
    """
    nat = make_nat(PRESETS[preset]).to(DEV)
    nat.load_state_dict(torch.load(ckpt, map_location=DEV))
    nat.eval()
    # FIX: make_nat() returns a bare nn.Module container with no forward(),
    # so the old `NATWrap(nat)` path crashed with NotImplementedError the
    # moment ParScale called it (the "nat inferencing broken" issue — note
    # ar_infer always routed through fwd()). Wrap the container in fwd().
    gen = ParScale(NATWrap(lambda ids: fwd(nat, ids)), P=streams).to(DEV)
    inp = torch.tensor([tok.encode(prompt) + [BLANK_ID]*max_new], device=DEV)
    t0 = time.time()
    out = gen.generate(inp, passes=passes)[0]
    dt = time.time() - t0
    txt = tok.decode([t for t in out.tolist() if t != BLANK_ID], skip_special_tokens=True)
    print(txt); print(f"[{len(txt.split()) - len(prompt.split())} new tokens in {dt:.2f}s]")
@torch.no_grad()
def ar_infer(ckpt, prompt, max_new, preset):
    """Greedy AR decoding: append the argmax token max_new times, print."""
    ar = make_ar(PRESETS[preset]).to(DEV)
    ar.load_state_dict(torch.load(ckpt, map_location=DEV))
    ar.eval()
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    t0 = time.time()
    for _ in range(max_new):
        logits = fwd(ar, ids, causal=True)
        nxt = logits[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, nxt], 1)
    dt = time.time() - t0
    txt = tok.decode(ids[0].tolist(), skip_special_tokens=True)
    print(txt)
    print(f"[{len(txt.split()) - len(prompt.split())} new tokens in {dt:.2f}s]")
# ───────────── CLI ─────────────
def main():
    """CLI entry: `train` to joint-train, `infer` to decode a checkpoint."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS.keys(), default="small")
    tr.add_argument("--epochs", type=int, default=128)
    tr.add_argument("--batch", type=int, default=2)
    tr.add_argument("--max_tokens", type=int, default=100_000_000)
    tr.add_argument("--amp", action="store_true")

    inf = sub.add_parser("infer")
    inf.add_argument("--preset", choices=PRESETS.keys(), default="small")
    inf.add_argument("--mode", choices=["nat","ar"], required=True)
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = parser.parse_args()
    if args.cmd == "train":
        train_joint(args)
    elif args.mode == "nat":
        nat_infer(args.ckpt, args.prompt, args.max_new,
                  args.passes, args.streams, args.preset)
    else:
        ar_infer(args.ckpt, args.prompt, args.max_new, args.preset)

if __name__ == "__main__":
    main()