trojan0x committed on
Commit
a8346e3
·
verified ·
1 Parent(s): 867c386

Add train_ultron.py

Browse files
Files changed (1) hide show
  1. train_ultron.py +434 -0
train_ultron.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Ultron Pretraining on FineWeb-Edu — HF Jobs Compatible
4
+
5
+ Two experiments:
6
+ 1. Ultron-small baseline (dense FFN, GQA) — the proven config
7
+ 2. Ultron-small MoE (experimental MoE in recurrent block)
8
+
9
+ Based on Parcae training recipe:
10
+ - AdamW (β1=0.9, β2=0.95), weight decay 0.1
11
+ - Cosine LR decay with linear warmup
12
+ - Per-sequence depth sampling
13
+ - bf16 mixed precision
14
+ - Gradient checkpointing for memory efficiency
15
+
16
+ Usage:
17
+ python train_ultron.py --experiment baseline --hub_model_id trojan0x/ultron-small-baseline
18
+ python train_ultron.py --experiment moe --hub_model_id trojan0x/ultron-small-moe
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import math
24
+ import time
25
+ import json
26
+ import argparse
27
+ from dataclasses import dataclass, asdict
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch.utils.data import IterableDataset, DataLoader
33
+
34
+ # ── Install deps ──────────────────────────────────────────────────
35
def ensure_deps():
    """Install the third-party packages this script needs, if any are missing.

    Each requirement is probed with ``importlib.util.find_spec`` and ``pip``
    is invoked only for the ones that cannot be imported, so a run in an
    environment that already has everything skips the pip subprocess
    entirely. (A plain ``pip install`` without ``-U`` does not upgrade
    packages that are already present, so skipping them does not change which
    versions end up installed.)

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command exits
            with a non-zero status.
    """
    import importlib
    import importlib.util
    import subprocess

    # pip distribution name -> top-level module name used to probe for it.
    requirements = {
        "datasets": "datasets",
        "transformers": "transformers",
        "huggingface_hub": "huggingface_hub",
        "trackio": "trackio",
        "lm-eval": "lm_eval",
    }
    missing = [
        package
        for package, module in requirements.items()
        if importlib.util.find_spec(module) is None
    ]
    if not missing:
        return

    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", *missing]
    )
    # The packages were installed while this interpreter was running; drop the
    # import system's cached directory listings so the imports that follow
    # can see them.
    importlib.invalidate_caches()
39
+ ensure_deps()
40
+
41
+ import trackio
42
+ from datasets import load_dataset
43
+ from transformers import AutoTokenizer
44
+ from huggingface_hub import HfApi
45
+
46
+ # ── Ultron model (inline for self-contained job) ──────────────────
47
+ # We import the model from the repo files uploaded to the HF Hub.
47
+ # The code is not inlined here: the script's own directory is put on sys.path below.
49
+
50
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
51
+
52
+ # Import model — the ultron/ package should be alongside this script
53
+ from ultron.model import Ultron, UltronConfig
54
+
55
+
56
+ # ===========================================================================
57
+ # Dataset: FineWeb-Edu packed streaming
58
+ # ===========================================================================
59
+
60
class FineWebPackedDataset(IterableDataset):
    """Stream FineWeb-Edu, tokenize each document, and pack fixed-length chunks.

    Documents are tokenized without special tokens, terminated with the
    tokenizer's EOS id, and concatenated into one running token buffer. The
    buffer is cut into windows of ``seq_len + 1`` tokens that advance by
    ``seq_len``: consecutive windows overlap by one token, so the last token
    of one window is both the final label of that window and the first input
    of the next.

    Args:
        tokenizer: object exposing ``eos_token_id`` and
            ``encode(text, add_special_tokens=False)`` returning a list of ids.
        seq_len: number of tokens in each yielded ``input_ids`` / ``labels``.
        config: dataset configuration name passed to ``load_dataset``.
        seed: seed passed to the streaming shuffle.
    """

    def __init__(self, tokenizer, seq_len=1024, config="sample-10BT", seed=42):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.config = config
        self.seed = seed

    def _pack(self, samples):
        """Yield packed training examples from an iterable of dataset rows.

        Args:
            samples: iterable of mappings; the ``"text"`` entry of each is
                used and rows with missing, empty, or shorter-than-50-character
                text are skipped.

        Yields:
            dict with ``"input_ids"`` and ``"labels"``, both ``torch.long``
            tensors of length ``seq_len``; ``labels`` is ``input_ids`` shifted
            left by one token.

        Raises:
            ValueError: if the tokenizer has no ``eos_token_id``.
        """
        eos = self.tokenizer.eos_token_id
        if eos is None:
            # Appending None would only blow up later inside torch.tensor().
            raise ValueError("tokenizer.eos_token_id is None; cannot pack documents")

        window = self.seq_len + 1
        buffer = []

        for sample in samples:
            text = sample.get("text", "")
            if not text or len(text) < 50:
                continue
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            tokens.append(eos)
            buffer.extend(tokens)

            # Walk the buffer with an offset and trim it once per document,
            # instead of re-slicing the whole remainder for every chunk.
            start = 0
            while len(buffer) - start >= window:
                chunk = buffer[start:start + window]
                start += self.seq_len
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
            if start:
                del buffer[:start]

    def __iter__(self):
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.config,
            split="train",
            streaming=True,
        )
        ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
        yield from self._pack(ds)
96
+
97
+
98
+ # ===========================================================================
99
+ # Training utilities
100
+ # ===========================================================================
101
+
102
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    While ``step < warmup_steps`` the rate ramps linearly and reaches
    ``max_lr`` on the last warmup step. From ``max_steps`` onward it is
    pinned at ``min_lr``. In between it follows half a cosine period from
    ``max_lr`` down to ``min_lr``.
    """
    warming_up = step < warmup_steps
    if warming_up:
        return max_lr * (step + 1) / warmup_steps

    finished = step >= max_steps
    if finished:
        return min_lr

    # Guard against a zero-length decay phase (max_steps == warmup_steps).
    decay_span = max(1, max_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_span
    cosine_term = 1 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine_term
110
+
111
+
112
def sample_loop_depth(mu_rec, batch_size):
    """Sample one loop depth for a batch (Parcae-style depth sampling).

    A depth is drawn for each of the ``batch_size`` sequences from a
    geometric distribution with success probability ``1 / max(mu_rec, 1)``,
    shifted by one so the smallest draw is 1, and clamped to
    ``[1, 2 * mu_rec]``. The per-sequence depths are then collapsed to their
    floor-mean, because the caller runs the whole batch at a single depth
    (a simplification for efficiency).

    Args:
        mu_rec: target recurrence depth; also sets the upper clamp
            ``2 * mu_rec``.
        batch_size: number of per-sequence draws to average; must be >= 1.

    Returns:
        The floor of the mean clamped depth, never less than 1.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    depth_cap = 2 * mu_rec
    # The distribution does not change between draws, so build it once.
    geometric = torch.distributions.Geometric(probs=1.0 / max(mu_rec, 1))

    total_depth = 0
    # Draw one sample at a time so the RNG stream matches per-sequence sampling.
    for _ in range(batch_size):
        draw = int(geometric.sample().item()) + 1
        total_depth += max(1, min(depth_cap, draw))
    return max(1, total_depth // batch_size)
124
+
125
+
126
+ # ===========================================================================
127
+ # Main training function
128
+ # ===========================================================================
129
+
130
def _build_experiment_config(experiment, seq_len):
    """Return ``(cfg, run_name)`` for the named experiment ("baseline" or "moe")."""
    shared = dict(
        vocab_size=50257,  # GPT-2 vocab
        dim=768,
        n_heads=12,
        n_kv_heads=4,
        max_seq_len=seq_len,
        prelude_layers=2,
        coda_layers=2,
        recurrent_layers=4,
        max_loop_iters=8,
        attn_type="gqa",
        lora_rank=8,
        act_threshold=0.99,
        gradient_checkpointing=True,
        dropout=0.0,
    )
    if experiment == "baseline":
        return UltronConfig(use_moe=False, **shared), "ultron-small-baseline"
    if experiment == "moe":
        moe_cfg = UltronConfig(
            use_moe=True,
            n_experts=8,
            n_shared_experts=1,
            n_experts_per_tok=2,
            expert_dim=384,
            **shared,
        )
        return moe_cfg, "ultron-small-moe"
    raise ValueError(f"Unknown experiment: {experiment}")


def _upload_to_hub(repo_id, local_path, path_in_repo):
    """Upload one local file to ``repo_id`` on the Hugging Face Hub."""
    HfApi().upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
    )


def train(args):
    """Pretrain an Ultron model on FineWeb-Edu and optionally push it to the Hub.

    Builds the model for ``args.experiment``, streams packed batches, and runs
    the loop with a warmup + cosine learning-rate schedule, a sampled loop
    depth per batch, gradient accumulation, and gradient-norm clipping at 1.0.
    Metrics are printed every ``args.log_interval`` steps (and sent to Trackio
    when a space id is configured). A checkpoint is written every
    ``args.save_interval`` steps, and a final ``ultron_final.pt`` is written
    at the end.

    Args:
        args: namespace with the attributes ``experiment``, ``seq_len``,
            ``dataset_config``, ``batch_size``, ``grad_accum``, ``lr``,
            ``min_lr``, ``warmup_steps``, ``max_steps``, ``log_interval``,
            ``save_interval``, ``hub_model_id`` and ``trackio_space``.

    Raises:
        ValueError: if ``args.experiment`` is not a known experiment name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    print(f"Device: {device} | dtype: {dtype}")

    # ── Model config ──────────────────────────────────────────────
    cfg, run_name = _build_experiment_config(args.experiment, args.seq_len)

    # ── Build model ───────────────────────────────────────────────
    model = Ultron(cfg).to(device)
    total_params = model.get_num_params(non_embedding=False)
    non_emb_params = model.get_num_params(non_embedding=True)
    print(f"\n{'='*60}")
    print(f"Ultron [{args.experiment}]")
    print(f" Total params: {total_params:,}")
    print(f" Non-emb params: {non_emb_params:,}")
    print(f" ρ(A): {model.get_spectral_radius():.6f}")
    print(f" Config: {json.dumps(asdict(cfg), indent=2, default=str)}")
    print(f"{'='*60}\n")

    # ── Tokenizer ─────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # ── Dataset ───────────────────────────────────────────────────
    dataset = FineWebPackedDataset(
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        config=args.dataset_config,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=2,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
        pin_memory=device.type == "cuda",
        prefetch_factor=4,
    )

    # ── Optimizer ─────────────────────────────────────────────────
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.1,
    )

    # ── Trackio ───────────────────────────────────────────────────
    # The environment variable takes precedence over the CLI argument.
    trackio_space = os.environ.get("TRACKIO_SPACE_ID", args.trackio_space)
    if trackio_space:
        trackio.init(
            project="ultron-pretraining",
            name=run_name,
            space_id=trackio_space,
            config={
                "experiment": args.experiment,
                "total_params": total_params,
                "seq_len": args.seq_len,
                "batch_size": args.batch_size,
                "grad_accum": args.grad_accum,
                "lr": args.lr,
                "max_steps": args.max_steps,
                "use_moe": cfg.use_moe,
                "loop_iters": cfg.max_loop_iters,
                "recurrent_layers": cfg.recurrent_layers,
            },
            auto_log_gpu=True,
            gpu_log_interval=30.0,
        )
        print(f"Trackio initialized: {trackio_space}")
    else:
        print("Trackio: no space_id set, logging to stdout only")

    # ── Training loop ─────────────────────────────────────────────
    model.train()
    step = 0
    tokens_seen = 0
    running_loss = 0.0
    # Most recent logged average; stays inf until the first log interval ends.
    avg_loss = float("inf")
    t0 = time.time()
    log_t0 = time.time()

    effective_batch = args.batch_size * args.grad_accum
    print(f"\nTraining for {args.max_steps} steps")
    print(f" Batch size: {args.batch_size} × {args.grad_accum} accum = {effective_batch}")
    print(f" Sequence length: {args.seq_len}")
    print(f" Tokens per step: {effective_batch * args.seq_len:,}")
    print(f" bf16: {use_bf16}")
    print(f" Gradient checkpointing: {cfg.gradient_checkpointing}")
    print()

    optimizer.zero_grad()

    # NOTE(review): `step` counts micro-batches, not optimizer updates, so
    # max_steps, warmup_steps and the LR schedule are all in micro-batch
    # units and the optimizer steps max_steps // grad_accum times. Confirm
    # that this is the intended meaning of "step".
    for batch in loader:
        if step >= args.max_steps:
            break

        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # LR schedule
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.lr, args.min_lr)
        for group in optimizer.param_groups:
            group["lr"] = lr

        # Depth sampling (Parcae): one depth for the whole batch.
        n_loops = sample_loop_depth(cfg.max_loop_iters, input_ids.shape[0])

        # Forward + loss
        with torch.autocast(device_type="cuda", dtype=dtype, enabled=use_bf16):
            logits = model(input_ids, n_loops=n_loops)
            loss = F.cross_entropy(
                logits.view(-1, cfg.vocab_size),
                labels.view(-1),
            )
            # Scale so the accumulated gradient is a mean over micro-batches.
            loss_scaled = loss / args.grad_accum

        # Backward
        loss_scaled.backward()

        running_loss += loss.item()
        tokens_seen += input_ids.numel()

        # Gradient accumulation step
        if (step + 1) % args.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        step += 1

        # ── Logging ───────────────────────────────────────────────
        if step % args.log_interval == 0:
            avg_loss = running_loss / args.log_interval
            ppl = math.exp(min(avg_loss, 20))  # cap the exponent to avoid overflow
            rho = model.get_spectral_radius()
            dt = time.time() - log_t0
            tok_per_sec = (args.log_interval * input_ids.numel()) / max(dt, 1e-6)
            elapsed = time.time() - t0

            print(f"step {step:>6d}/{args.max_steps} | loss {avg_loss:.4f} | ppl {ppl:.1f} | "
                  f"lr {lr:.2e} | ρ(A) {rho:.4f} | depth {n_loops} | "
                  f"tok/s {tok_per_sec:,.0f} | {elapsed:.0f}s")

            if trackio_space:
                trackio.log({
                    "train/loss": avg_loss,
                    "train/perplexity": ppl,
                    "train/lr": lr,
                    "train/spectral_radius": rho,
                    "train/loop_depth": n_loops,
                    "train/tokens_seen": tokens_seen,
                    "train/tok_per_sec": tok_per_sec,
                })

            running_loss = 0.0
            log_t0 = time.time()

        # ── Save checkpoint ───────────────────────────────────────
        if step % args.save_interval == 0 and step > 0:
            ckpt = {
                "step": step,
                "tokens_seen": tokens_seen,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": asdict(cfg),
                "loss": avg_loss,
            }
            ckpt_path = f"ultron_ckpt_step{step}.pt"
            torch.save(ckpt, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")

            # Push to hub. The local file is removed only after a successful
            # upload; with no hub repo, or a failed push, it is the only copy.
            if args.hub_model_id:
                try:
                    _upload_to_hub(
                        args.hub_model_id, ckpt_path, f"checkpoints/{ckpt_path}"
                    )
                except Exception as e:
                    # Best effort: a failed upload must not abort training.
                    print(f" Hub push failed: {e}")
                    print(f" Keeping local checkpoint: {ckpt_path}")
                else:
                    print(f" Pushed to {args.hub_model_id}")
                    # Clean up local file to save space
                    if os.path.exists(ckpt_path):
                        os.remove(ckpt_path)

    # ── Final save ────────────────────────────────────────────────
    elapsed = time.time() - t0
    # Average over the steps since the last log; if the run ended exactly on
    # a log boundary there are none, so report the last logged average.
    pending_steps = step % args.log_interval
    final_loss = running_loss / pending_steps if pending_steps else avg_loss
    print(f"\nTraining complete! {step} steps in {elapsed:.0f}s ({elapsed/3600:.1f}h)")
    print(f"Final loss: {final_loss:.4f}")
    print(f"Final ρ(A): {model.get_spectral_radius():.6f}")
    print(f"Tokens seen: {tokens_seen:,}")

    # Save final model
    final = {
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
    }
    final_path = "ultron_final.pt"
    torch.save(final, final_path)

    if args.hub_model_id:
        # NOTE(review): upload_file presumably needs the repo to exist
        # already; consider creating it up front — confirm.
        try:
            _upload_to_hub(args.hub_model_id, final_path, "ultron_final.pt")
            # Also upload config
            config_path = "config.json"
            with open(config_path, "w") as f:
                json.dump(asdict(cfg), f, indent=2, default=str)
            _upload_to_hub(args.hub_model_id, config_path, "config.json")
            print(f"Final model pushed to {args.hub_model_id}")
        except Exception as e:
            # Best effort: the final weights remain on disk at final_path.
            print(f"Final push failed: {e}")

    if trackio_space:
        trackio.finish()

    print("Done!")
407
+
408
+
409
+ # ===========================================================================
410
+ # CLI
411
+ # ===========================================================================
412
+
413
+ def main():
414
+ parser = argparse.ArgumentParser(description="Ultron Pretraining")
415
+ parser.add_argument("--experiment", type=str, default="baseline",
416
+ choices=["baseline", "moe"])
417
+ parser.add_argument("--dataset_config", type=str, default="sample-10BT")
418
+ parser.add_argument("--seq_len", type=int, default=1024)
419
+ parser.add_argument("--batch_size", type=int, default=8)
420
+ parser.add_argument("--grad_accum", type=int, default=8)
421
+ parser.add_argument("--lr", type=float, default=3e-4)
422
+ parser.add_argument("--min_lr", type=float, default=3e-5)
423
+ parser.add_argument("--warmup_steps", type=int, default=1000)
424
+ parser.add_argument("--max_steps", type=int, default=10000)
425
+ parser.add_argument("--log_interval", type=int, default=10)
426
+ parser.add_argument("--save_interval", type=int, default=2000)
427
+ parser.add_argument("--hub_model_id", type=str, default=None)
428
+ parser.add_argument("--trackio_space", type=str, default=None)
429
+ args = parser.parse_args()
430
+ train(args)
431
+
432
+
433
+ if __name__ == "__main__":
434
+ main()