Upload a3.py
Inference code — NAT inference is currently broken (crashes when invoked), while AR inference works.
a3.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
a2.py – joint-train low-rank AR + NAT, auto-resume at epoch 84
|
| 4 |
+
|
| 5 |
+
• Loads ar_ep084.pt & nat_ep084.pt from ckpts1/ if present, then trains
|
| 6 |
+
from epoch 85. Otherwise starts from scratch.
|
| 7 |
+
• Dataset: WikiText-103 (raw) streamed, default cap = 100 M tokens.
|
| 8 |
+
• Checkpoints: epoch 1, every 5 epochs, and final.
|
| 9 |
+
• Default preset = small (fits 11 GB GPUs).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
import argparse, math, pathlib, time
|
| 14 |
+
from contextlib import nullcontext
|
| 15 |
+
|
| 16 |
+
import torch, torch.nn as nn
|
| 17 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
from transformers import AutoTokenizer, logging as hf_log
|
| 21 |
+
|
| 22 |
+
# ╭─ AMP shim ─╮
# Prefer the device-agnostic torch.amp API; fall back to the legacy
# torch.cuda.amp module on older PyTorch builds.
try:
    from torch.amp import autocast as _ac_new
    from torch.amp import GradScaler
    _AMP = "new"
except ImportError:  # torch < 2.2
    from torch.cuda.amp import autocast as _ac_old
    from torch.cuda.amp import GradScaler
    _AMP = "old"

def amp(enabled, dtype, device="cuda"):
    """Return an autocast context for *device*/*dtype*, or a no-op when disabled."""
    if not enabled:
        return nullcontext()
    if _AMP == "new":
        return _ac_new(device_type=device, dtype=dtype)
    return _ac_old(dtype=dtype)
# ╰─────────────╯
|
| 37 |
+
|
| 38 |
+
hf_log.set_verbosity_error()
|
| 39 |
+
torch.backends.cuda.matmul.allow_tf32 = True # free speed-up on Ampere+
|
| 40 |
+
|
| 41 |
+
# ───────────── presets ─────────────
|
| 42 |
+
PRESETS = {
|
| 43 |
+
"small": dict(ar_d=512, ar_layers=8, ar_heads=16,
|
| 44 |
+
nat_d=640, nat_layers=12, nat_heads=20),
|
| 45 |
+
"base": dict(ar_d=768, ar_layers=12, ar_heads=24,
|
| 46 |
+
nat_d=1024,nat_layers=16, nat_heads=32),
|
| 47 |
+
"large": dict(ar_d=1024, ar_layers=16, ar_heads=32,
|
| 48 |
+
nat_d=1280,nat_layers=24, nat_heads=40),
|
| 49 |
+
}
|
| 50 |
+
BLOCK = 128
|
| 51 |
+
DROP_P = 0.1
|
| 52 |
+
LR_AR = LR_NAT = 2e-4
|
| 53 |
+
ALPHA_KL = 1.0
|
| 54 |
+
CKDIR = pathlib.Path("ckpts1")
|
| 55 |
+
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 56 |
+
SAVE_EVERY = 5
|
| 57 |
+
RESUME_EPOCH= 84 # ← hard-coded resume point
|
| 58 |
+
|
| 59 |
+
# ───────────── tokenizer ─────────────
# NOTE(review): downloads the DeepSeek-R1 tokenizer at import time — requires
# network access (or a warm HF cache) before anything else in the script runs.
tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-0528", use_fast=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
BLANK_ID = tok.pad_token_id                  # pad token doubles as the CTC blank
VOCAB = max(tok.get_vocab().values()) + 1    # dense vocab size incl. any added pad
|
| 65 |
+
|
| 66 |
+
# ───────────── data streaming ─────────────
def stream_wikitext(max_tokens=0):
    """Stream token ids from WikiText-103 (raw); stop after *max_tokens* (0 = no cap)."""
    emitted = 0
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1",
                           split="train", streaming=True)
    for example in dataset:
        for token_id in tok.encode(example["text"]):
            yield token_id
            emitted += 1
            if max_tokens and emitted >= max_tokens:
                return
|
| 77 |
+
|
| 78 |
+
class ARDataset(IterableDataset):
    """Stride-1 sliding windows of (input, next-token target) pairs for the AR model."""

    def __init__(self, blk, max_tokens=0):
        self.blk = blk
        self.max = max_tokens

    def __iter__(self):
        window = []
        for token in stream_wikitext(self.max):
            window.append(token)
            # One extra token beyond blk lets us emit a shifted target.
            while len(window) > self.blk:
                inp = torch.tensor(window[:self.blk])
                tgt = torch.tensor(window[1:self.blk + 1])
                yield inp, tgt
                window = window[1:]
|
| 88 |
+
|
| 89 |
+
class NATDataset(IterableDataset):
    """Non-overlapping target blocks; inputs interleave a CTC blank before each target token."""

    def __init__(self, blk, max_tokens=0):
        self.blk = blk
        self.max = max_tokens

    def __iter__(self):
        pending = []
        for token in stream_wikitext(self.max):
            pending.append(token)
            while len(pending) >= self.blk:
                tgt, pending = pending[:self.blk], pending[self.blk:]
                # Input layout: [BLANK, t0, BLANK, t1, ...] — length 2*blk.
                inp = []
                for real in tgt:
                    inp.append(BLANK_ID)
                    inp.append(real)
                yield torch.tensor(inp), torch.tensor(tgt)
|
| 101 |
+
|
| 102 |
+
# ───────────── transformer components ─────────────
class LowRankMHA(nn.Module):
    """Multi-head attention whose per-head Q/K/V are projected down to rank *r*."""

    def __init__(self, d, h, r):
        super().__init__()
        self.h, self.dk = h, d // h
        # FIX: the original did `self.q = self.k = self.v = nn.Linear(...)`,
        # aliasing ONE Linear as all three projections, so queries, keys and
        # values always used identical weights. They must be independent
        # modules. Old checkpoints still load: their state dicts contain
        # q./k./v. entries (aliases of the same tensor).
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared low-rank basis applied per head (dk -> r).
        self.U = nn.Parameter(torch.randn(self.dk, r)); nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(DROP_P)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then project each head to rank r.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x, mask=None):
        """Attend over *x*; *mask* is an optional additive mask (e.g. -inf causal)."""
        q, k, v = map(self._proj, (self.q(x), self.k(x), self.v(x)))
        # NOTE(review): scaling uses sqrt(dk) even though q/k live in rank-r
        # space — sqrt(r) may be the intended scale; confirm before changing.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask  # out-of-place add avoids in-place mutation on the graph
        out = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        return self.drop(self.proj(out))
|
| 123 |
+
|
| 124 |
+
class Block(nn.Module):
    """Pre-norm transformer block: LN → low-rank MHA and LN → FFN, each residual."""

    def __init__(self, d, h, dff, r):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r)
        self.ff = nn.Sequential(
            nn.Linear(d, dff),
            nn.ReLU(),
            nn.Dropout(DROP_P),
            nn.Linear(dff, d),
        )

    def forward(self, x, mask=None):
        x = x + self.mha(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x
|
| 136 |
+
|
| 137 |
+
# ───────────── model builders ─────────────
def make_transformer(d, n_layers, n_heads, vocab, max_len=8192):
    """Assemble a bare nn.Module container (emb/pos/blocks/ln/out).

    The container defines no forward(); run it through fwd().
    """
    dff = 4 * d
    low_rank = max(32, d // 16)
    model = nn.Module()
    model.emb = nn.Embedding(vocab, d)
    model.pos = nn.Embedding(max_len, d)
    model.blocks = nn.ModuleList(
        Block(d, n_heads, dff, low_rank) for _ in range(n_layers)
    )
    model.ln = nn.LayerNorm(d)
    model.out = nn.Linear(d, vocab)
    return model
|
| 148 |
+
|
| 149 |
+
def make_ar(cfg):
    """Build the AR model from preset *cfg* (positions up to 4096)."""
    return make_transformer(cfg["ar_d"], cfg["ar_layers"], cfg["ar_heads"], VOCAB, 4096)
|
| 151 |
+
def make_nat(cfg):
    """Build the NAT model from preset *cfg* (positions up to 8192 — NAT inputs are 2× long)."""
    return make_transformer(cfg["nat_d"], cfg["nat_layers"], cfg["nat_heads"], VOCAB, 8192)
|
| 153 |
+
|
| 154 |
+
# ───────────── NAT helpers ─────────────
class NATWrap(nn.Module):
    """Adapt the NAT core to raw token input by 2× upsampling each position.

    FIX: make_transformer() returns a bare nn.Module container with no
    forward(), so the original ``self.core(...)`` call raised
    NotImplementedError — this is why NAT inference crashed while AR
    inference (which already goes through fwd()) worked. Route the call
    through fwd() instead.
    """

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, x):
        # Duplicate every position so the core sees a 2×-length sequence, the
        # length it was trained on. NOTE(review): training interleaved BLANKs
        # with tokens, whereas this duplicates tokens — confirm intended.
        return fwd(self.core, torch.repeat_interleave(x, 2, 1))
|
| 158 |
+
|
| 159 |
+
class ParScale(nn.Module):
    """Parallel-stream NAT decoding: keep the top-P candidate id per position and pick one stream."""
    def __init__(self, nat, P): super().__init__(); self.nat,self.P = nat,P
    @torch.no_grad()
    def generate(self, x, passes=1):
        # x: (B, N) token ids; each pass refines the whole sequence at once.
        for _ in range(passes):
            # Crush the blank logit so blanks can never be emitted.
            logits = self.nat(x); logits[..., BLANK_ID] = -1e9
            # top-P ids per position: (B, 2N, P) -> (P, B, 2N) candidate streams.
            cand = logits.topk(self.P, -1).indices.permute(2,0,1)
            # Score each stream by its non-blank fraction; pick best per batch row.
            # NOTE(review): because blanks were just masked to -1e9, no candidate
            # can be BLANK_ID, so every stream scores 1.0 and argmax always picks
            # stream 0 — the P-way search is degenerate. Score unmasked logits
            # (or drop the mask) if stream selection is meant to do anything.
            best = (cand != BLANK_ID).float().mean(-1).argmax(0)
            # Take the winning stream, then every other position to undo the 2× upsampling.
            x = cand[best, torch.arange(x.size(0), device=x.device)][:, ::2]
        return x
|
| 169 |
+
|
| 170 |
+
# ───────────── helpers ─────────────
def fwd(model, ids, causal=False):
    """Run a make_transformer() container on *ids*; optionally apply a causal mask."""
    _, length = ids.shape
    positions = torch.arange(length, device=ids.device)
    hidden = model.emb(ids) + model.pos(positions)
    if causal:
        # Strict upper triangle of -inf blocks attention to future positions.
        mask = torch.triu(
            torch.full((1, 1, length, length), float("-inf"), device=ids.device), 1
        )
    else:
        mask = None
    for layer in model.blocks:
        hidden = layer(hidden, mask)
    return model.out(model.ln(hidden))
|
| 181 |
+
|
| 182 |
+
# ───────────── training ─────────────
def train_joint(a):
    """Jointly train the AR and NAT models (args from the `train` sub-command).

    Loss = AR cross-entropy + NAT CTC + ALPHA_KL * KL(NAT ‖ AR) distillation.
    Resumes from ckpts1/{ar,nat}_ep{RESUME_EPOCH}.pt when both files exist;
    checkpoints at epoch 1, every SAVE_EVERY epochs, and the final epoch.
    """
    cfg = PRESETS[a.preset]
    ar_loader = DataLoader(ARDataset(BLOCK, a.max_tokens), batch_size=a.batch)
    nat_loader = DataLoader(NATDataset(BLOCK, a.max_tokens), batch_size=a.batch)
    ar, nat = make_ar(cfg).to(DEV), make_nat(cfg).to(DEV)

    # ----- resume if we have the hard-coded resume-epoch weights -----
    start_ep = 0
    ck_ar = CKDIR / f"ar_ep{RESUME_EPOCH:03d}.pt"
    ck_nat = CKDIR / f"nat_ep{RESUME_EPOCH:03d}.pt"
    if ck_ar.exists() and ck_nat.exists():
        ar.load_state_dict(torch.load(ck_ar, map_location=DEV))
        nat.load_state_dict(torch.load(ck_nat, map_location=DEV))
        start_ep = RESUME_EPOCH
        print(f"Resuming from epoch {start_ep} checkpoints.")

    opt = torch.optim.AdamW(
        [{"params": ar.parameters(), "lr": LR_AR},
         {"params": nat.parameters(), "lr": LR_NAT}]
    )

    # Ensure 'initial_lr' exists so the scheduler can resume mid-run
    # (required when last_epoch != -1).
    for pg in opt.param_groups:
        pg.setdefault("initial_lr", pg["lr"])

    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=a.epochs, last_epoch=start_ep - 1
    )

    ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True)
    kl = nn.KLDivLoss(reduction="batchmean")

    use_amp = DEV.type == "cuda" and a.amp
    # NOTE(review): GradScaler targets fp16; with bf16 autocast the scaling is
    # unnecessary (though harmless).
    scaler = GradScaler(enabled=use_amp)
    cast_dt = torch.bfloat16 if use_amp else torch.float32
    CKDIR.mkdir(exist_ok=True)
    tot_batches = None if not a.max_tokens else math.ceil(
        math.ceil(a.max_tokens / BLOCK) / a.batch)

    for ep in range(start_ep + 1, a.epochs + 1):
        ar.train(); nat.train(); tot = steps = 0
        loop = tqdm(zip(ar_loader, nat_loader), total=tot_batches,
                    desc=f"Epoch {ep}/{a.epochs}", unit="batch")
        for (x_ar, y_ar), (x_nat, y_nat) in loop:
            x_ar, y_ar, x_nat, y_nat = (t.to(DEV) for t in (x_ar, y_ar, x_nat, y_nat))
            opt.zero_grad(set_to_none=True)
            with amp(use_amp, cast_dt, DEV.type):
                # AR next-token cross-entropy.
                logits_ar = fwd(ar, x_ar, causal=True)
                loss_ar = ce(logits_ar.reshape(-1, VOCAB), y_ar.reshape(-1))

                # NAT CTC loss; log-probs must be (T, B, V).
                logp_nat = fwd(nat, x_nat).log_softmax(-1).transpose(0, 1)
                # FIX: input_lengths must equal the log-prob time dimension.
                # x_nat is the 2×-length blank-interleaved input, so T is
                # x_nat.size(1); the original passed x_nat.size(1)//2 for BOTH
                # lengths, truncating CTC to the first half of every frame
                # sequence and leaving no room for required blanks between
                # repeated targets (those batches collapsed to 0 loss via
                # zero_infinity).
                ilen = torch.full((x_nat.size(0),), x_nat.size(1),
                                  dtype=torch.long, device=DEV)
                tlen = torch.full((y_nat.size(0),), y_nat.size(1),
                                  dtype=torch.long, device=DEV)
                loss_nat = ctc(logp_nat, y_nat, ilen, tlen)

                # Distill the (detached) AR distribution into the NAT model
                # on the AR batch.
                loss_kld = kl(fwd(nat, x_ar).log_softmax(-1),
                              logits_ar.softmax(-1).detach())

                loss = loss_ar + loss_nat + ALPHA_KL * loss_kld

            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping so the norm is real
            nn.utils.clip_grad_norm_(ar.parameters(), 1.0)
            nn.utils.clip_grad_norm_(nat.parameters(), 1.0)
            scaler.step(opt); scaler.update()

            tot += loss.item(); steps += 1
            loop.set_postfix(loss=f"{loss.item():.3f}",
                             avg=f"{tot/steps:.3f}", refresh=False)
        sched.step()

        if ep == 1 or ep % SAVE_EVERY == 0 or ep == a.epochs:
            torch.save(nat.state_dict(), CKDIR / f"nat_ep{ep:03d}.pt")
            torch.save(ar.state_dict(), CKDIR / f"ar_ep{ep:03d}.pt")
            print(f"Epoch {ep}: checkpoints written.")
        print(f"Epoch {ep}: avg loss {tot/max(steps,1):.4f}")
|
| 262 |
+
|
| 263 |
+
# ───────────── inference helpers ─────────────
@torch.no_grad()
def nat_infer(ckpt, prompt, max_new, passes, streams, preset):
    """Decode with the NAT model: prompt + max_new blank slots, refined in parallel."""
    nat = make_nat(PRESETS[preset]).to(DEV)
    nat.load_state_dict(torch.load(ckpt, map_location=DEV))
    nat.eval()
    gen = ParScale(NATWrap(nat), P=streams).to(DEV)
    seed_ids = tok.encode(prompt) + [BLANK_ID] * max_new
    inp = torch.tensor([seed_ids], device=DEV)
    t0 = time.time()
    out = gen.generate(inp, passes=passes)[0]
    dt = time.time() - t0
    kept = [t for t in out.tolist() if t != BLANK_ID]
    txt = tok.decode(kept, skip_special_tokens=True)
    print(txt)
    print(f"[{len(txt.split()) - len(prompt.split())} new tokens in {dt:.2f}s]")
|
| 273 |
+
|
| 274 |
+
@torch.no_grad()
def ar_infer(ckpt, prompt, max_new, preset):
    """Greedy autoregressive decoding: append the argmax token *max_new* times."""
    ar = make_ar(PRESETS[preset]).to(DEV)
    ar.load_state_dict(torch.load(ckpt, map_location=DEV))
    ar.eval()
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    t0 = time.time()
    for _ in range(max_new):
        logits = fwd(ar, ids, causal=True)
        next_id = logits[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, next_id], 1)
    dt = time.time() - t0
    txt = tok.decode(ids[0].tolist(), skip_special_tokens=True)
    print(txt)
    print(f"[{len(txt.split()) - len(prompt.split())} new tokens in {dt:.2f}s]")
|
| 285 |
+
|
| 286 |
+
# ───────────── CLI ─────────────
def main():
    """Parse the `train` / `infer` sub-commands and dispatch."""
    p = argparse.ArgumentParser()
    sub = p.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS.keys(), default="small")
    tr.add_argument("--epochs", type=int, default=128)
    tr.add_argument("--batch", type=int, default=2)
    tr.add_argument("--max_tokens", type=int, default=100_000_000)
    tr.add_argument("--amp", action="store_true")

    inf = sub.add_parser("infer")
    inf.add_argument("--preset", choices=PRESETS.keys(), default="small")
    inf.add_argument("--mode", choices=["nat","ar"], required=True)
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = p.parse_args()
    if args.cmd == "train":
        train_joint(args)
    elif args.mode == "nat":
        nat_infer(args.ckpt, args.prompt, args.max_new,
                  args.passes, args.streams, args.preset)
    else:
        ar_infer(args.ckpt, args.prompt, args.max_new, args.preset)

if __name__ == "__main__":
    main()
|