ultron / train_ultron.py
trojan0x's picture
Update train_ultron.py
723c0d8 verified
#!/usr/bin/env python3
"""
Ultron Pretraining on FineWeb-Edu — HF Jobs Compatible
Two experiments:
1. Ultron-small baseline (dense FFN, GQA) — the proven config
2. Ultron-small MoE (experimental MoE in recurrent block)
Based on the Parcae training recipe:
- AdamW (β1=0.9, β2=0.95), weight decay 0.1
- Cosine LR decay with linear warmup
- Per-sequence depth sampling
- bf16 mixed precision
- Gradient checkpointing for memory efficiency
Usage:
python train_ultron.py --experiment baseline --hub_model_id trojan0x/ultron-small-baseline
python train_ultron.py --experiment moe --hub_model_id trojan0x/ultron-small-moe
"""
import os
import sys
import math
import time
import json
import argparse
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import trackio
from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import HfApi
# ── Ultron model — download from Hub ──────────────────────────────
def setup_ultron():
    """Make the ``ultron`` package from the Hub importable.

    Fetches only the ``ultron/*.py`` files of the ``trojan0x/ultron`` repo via
    ``snapshot_download`` and puts the returned snapshot directory at the front
    of ``sys.path`` so that ``import ultron`` resolves to it.
    """
    from huggingface_hub import snapshot_download

    snapshot_dir = snapshot_download(
        "trojan0x/ultron",
        allow_patterns=["ultron/*.py"],
    )
    sys.path.insert(0, snapshot_dir)
    print(f"Ultron package loaded from: {snapshot_dir}")


# Must run before the ``from ultron.model import ...`` line that follows.
setup_ultron()
from ultron.model import Ultron, UltronConfig
# ===========================================================================
# Dataset: FineWeb-Edu packed streaming
# ===========================================================================
class FineWebPackedDataset(IterableDataset):
    """Streams FineWeb-Edu, tokenizes, and packs into fixed-length chunks.

    Each sample whose ``text`` has at least 50 characters is tokenized without
    special tokens and terminated with the tokenizer's EOS id. The resulting
    token stream is cut into windows of ``seq_len + 1`` tokens that advance by
    ``seq_len`` (adjacent windows share one token). Every window is yielded as
    ``{"input_ids": window[:-1], "labels": window[1:]}``, both ``torch.long``
    tensors of length ``seq_len``. Tokens still buffered when the stream ends
    are dropped.

    Args:
        tokenizer: Provides ``encode(text, add_special_tokens=False)`` and
            ``eos_token_id`` (a Hugging Face tokenizer in this script).
        seq_len: Length of each yielded ``input_ids`` / ``labels`` row (>= 1).
        config: FineWeb-Edu subset, passed to ``load_dataset(name=...)``.
        seed: Seed for the streaming shuffle buffer.

    Raises:
        ValueError: If ``seq_len < 1`` (the packer could never advance), or,
            during iteration, if the tokenizer has no EOS id.
    """

    def __init__(self, tokenizer, seq_len=1024, config="sample-10BT", seed=42):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.config = config
        self.seed = seed

    def __iter__(self):
        # NOTE(review): with DataLoader(num_workers > 1) every worker runs this
        # method. Recent `datasets` releases split a streaming dataset's shards
        # across PyTorch workers automatically; confirm that for the installed
        # version, otherwise each worker yields the same data.
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.config,
            split="train",
            streaming=True,
        )
        ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
        yield from self._pack(self._tokenize(ds))

    def _tokenize(self, samples):
        """Yield an EOS-terminated token-id list for each usable sample."""
        eos = self.tokenizer.eos_token_id
        if eos is None:
            # Otherwise ``None`` would land in the buffer and only fail later,
            # inside ``torch.tensor``, with a confusing message.
            raise ValueError("tokenizer has no eos_token_id; cannot delimit documents")
        for sample in samples:
            text = sample.get("text", "")
            if not text or len(text) < 50:
                continue
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            tokens.append(eos)
            yield tokens

    def _pack(self, token_lists):
        """Yield ``input_ids``/``labels`` windows from consecutive token lists.

        Consumed tokens are tracked with an offset and trimmed once per input
        list, instead of re-slicing the whole buffer after every window (which
        copies the remainder each time and is quadratic on long documents).
        """
        window = self.seq_len + 1
        buffer = []
        for tokens in token_lists:
            buffer.extend(tokens)
            start = 0
            while len(buffer) - start >= window:
                chunk = buffer[start:start + window]
                # Advance by seq_len, not seq_len + 1: the last label of this
                # window is the first input of the next one.
                start += self.seq_len
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
            if start:
                del buffer[:start]
# ===========================================================================
# Training utilities
# ===========================================================================
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    * ``step < warmup_steps``: ``max_lr * (step + 1) / warmup_steps``, so the
      last warmup step already uses ``max_lr``.
    * ``step >= max_steps``: ``min_lr``.
    * otherwise: half-cosine from ``max_lr`` (at ``warmup_steps``) down
      toward ``min_lr`` (at ``max_steps``).
    """
    if step >= warmup_steps:
        if step >= max_steps:
            return min_lr
        decay_len = max(1, max_steps - warmup_steps)
        frac = (step - warmup_steps) / decay_len
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * frac))
    return max_lr * (step + 1) / warmup_steps
def sample_loop_depth(mu_rec, batch_size):
    """Pick one recurrence depth for a batch (Parcae-style depth sampling).

    A depth is drawn per sequence as ``1 + Geometric(p=1/max(mu_rec, 1))``.
    Per PyTorch docs, ``Geometric`` counts failures before the first success,
    so the unclamped mean is ``max(mu_rec, 1)``. Each draw is clamped to
    ``[1, 2 * mu_rec]``. For efficiency the whole batch then shares a single
    depth: the floor of the per-sequence mean, never below 1.

    Args:
        mu_rec: Target mean recurrence depth.
        batch_size: Number of per-sequence draws to average (must be >= 1).

    Returns:
        int: Loop depth to use for the whole batch.
    """
    geometric = torch.distributions.Geometric(probs=1.0 / max(mu_rec, 1))
    upper = 2 * mu_rec
    total = 0
    for _ in range(batch_size):
        draw = int(geometric.sample().item()) + 1
        total += max(1, min(upper, draw))
    return max(1, total // batch_size)
# ===========================================================================
# Main training function
# ===========================================================================
def _build_config(experiment, seq_len):
    """Return ``(cfg, run_name)`` for one of the two supported experiments.

    Both experiments share every backbone setting below; ``"moe"`` differs
    only by turning on the mixture-of-experts options.

    Args:
        experiment: ``"baseline"`` or ``"moe"``.
        seq_len: Training sequence length, used as ``max_seq_len``.

    Raises:
        ValueError: For any other ``experiment`` value.
    """
    shared = dict(
        vocab_size=50257,  # GPT-2 vocab
        dim=768,
        n_heads=12,
        n_kv_heads=4,
        max_seq_len=seq_len,
        prelude_layers=2,
        coda_layers=2,
        recurrent_layers=4,
        max_loop_iters=8,
        attn_type="gqa",
        lora_rank=8,
        act_threshold=0.99,
        gradient_checkpointing=True,
        dropout=0.0,
    )
    if experiment == "baseline":
        return UltronConfig(**shared, use_moe=False), "ultron-small-baseline"
    if experiment == "moe":
        cfg = UltronConfig(
            **shared,
            use_moe=True,
            n_experts=8,
            n_shared_experts=1,
            n_experts_per_tok=2,
            expert_dim=384,
        )
        return cfg, "ultron-small-moe"
    raise ValueError(f"Unknown experiment: {experiment}")


def _push_file(api, repo_id, local_path, path_in_repo):
    """Best-effort upload of ``local_path`` to ``path_in_repo`` in ``repo_id``.

    Returns ``None`` on success, otherwise the exception that was raised, so
    callers can report the failure without aborting a long training run.
    """
    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
        )
    except Exception as exc:  # deliberately broad: a Hub hiccup must not kill training
        return exc
    return None


def train(args):
    """Pretrain an Ultron model on packed FineWeb-Edu according to ``args``.

    ``step`` counts optimizer updates. Each update accumulates gradients over
    ``args.grad_accum`` micro-batches of ``args.batch_size`` sequences, so
    ``max_steps``, ``warmup_steps``, ``log_interval`` and ``save_interval`` are
    all measured in updates of ``batch_size * grad_accum * seq_len`` tokens
    (the "Tokens per step" figure printed at start-up).

    Every ``save_interval`` updates a model+optimizer checkpoint is written.
    When ``args.hub_model_id`` is set it is pushed to ``checkpoints/`` in that
    repo, and the local file is deleted only after a successful push. At the
    end a model-only ``ultron_final.pt`` is saved, and it is pushed together
    with ``config.json`` when a Hub repo is configured.

    Args:
        args: Parsed CLI namespace produced by ``main()``.

    Raises:
        ValueError: If ``args.experiment`` is not ``"baseline"`` or ``"moe"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    print(f"Device: {device} | dtype: {dtype}")

    # ── Model ─────────────────────────────────────────────────────
    cfg, run_name = _build_config(args.experiment, args.seq_len)
    model = Ultron(cfg).to(device)
    total_params = model.get_num_params(non_embedding=False)
    non_emb_params = model.get_num_params(non_embedding=True)
    print(f"\n{'='*60}")
    print(f"Ultron [{args.experiment}]")
    print(f"  Total params: {total_params:,}")
    print(f"  Non-emb params: {non_emb_params:,}")
    print(f"  ρ(A): {model.get_spectral_radius():.6f}")
    print(f"  Config: {json.dumps(asdict(cfg), indent=2, default=str)}")
    print(f"{'='*60}\n")

    # ── Tokenizer + data ──────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    dataset = FineWebPackedDataset(
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        config=args.dataset_config,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=2,
        # Pinned host memory only pays off (and only avoids a warning) on GPU.
        pin_memory=device.type == "cuda",
        prefetch_factor=4,
    )

    # ── Optimizer ─────────────────────────────────────────────────
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.1,
    )

    # ── Hub ───────────────────────────────────────────────────────
    api = None
    if args.hub_model_id:
        api = HfApi()
        try:
            # upload_file targets an existing repo; create it up front if needed.
            api.create_repo(repo_id=args.hub_model_id, exist_ok=True)
        except Exception as exc:  # Hub is best-effort; never block training
            print(f"Could not create/verify Hub repo {args.hub_model_id}: {exc}")

    # ── Trackio ───────────────────────────────────────────────────
    trackio_space = os.environ.get("TRACKIO_SPACE_ID", args.trackio_space)
    if trackio_space:
        trackio.init(
            project="ultron-pretraining",
            name=run_name,
            space_id=trackio_space,
            config={
                "experiment": args.experiment,
                "total_params": total_params,
                "seq_len": args.seq_len,
                "batch_size": args.batch_size,
                "grad_accum": args.grad_accum,
                "lr": args.lr,
                "max_steps": args.max_steps,
                "use_moe": cfg.use_moe,
                "loop_iters": cfg.max_loop_iters,
                "recurrent_layers": cfg.recurrent_layers,
            },
            auto_log_gpu=True,
            gpu_log_interval=30.0,
        )
        print(f"Trackio initialized: {trackio_space}")
    else:
        print("Trackio: no space_id set, logging to stdout only")

    # ── Training loop ─────────────────────────────────────────────
    model.train()
    step = 0            # optimizer updates performed
    micro_step = 0      # forward/backward passes performed
    tokens_seen = 0
    running_loss = 0.0  # sum of micro-batch losses since the last log line
    running_count = 0   # number of micro-batches summed into running_loss
    last_avg_loss = float("inf")
    t0 = time.time()
    log_t0 = t0
    log_tokens = 0      # tokens_seen at the previous log line
    effective_batch = args.batch_size * args.grad_accum
    print(f"\nTraining for {args.max_steps} optimizer steps")
    print(f"  Batch size: {args.batch_size} × {args.grad_accum} accum = {effective_batch}")
    print(f"  Sequence length: {args.seq_len}")
    print(f"  Tokens per step: {effective_batch * args.seq_len:,}")
    print(f"  bf16: {use_bf16}")
    print(f"  Gradient checkpointing: {cfg.gradient_checkpointing}")
    print()

    optimizer.zero_grad()
    for batch in loader:
        if step >= args.max_steps:
            break
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)

        # LR is a function of the optimizer step, so it stays constant across
        # the micro-batches of one accumulation window.
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.lr, args.min_lr)
        for g in optimizer.param_groups:
            g["lr"] = lr

        # Per-sequence depth sampling (Parcae), collapsed to one depth per batch.
        # NOTE(review): max_loop_iters is passed as the *mean* depth, and
        # sample_loop_depth can return up to 2x that value; confirm the model
        # supports n_loops > cfg.max_loop_iters.
        n_loops = sample_loop_depth(cfg.max_loop_iters, input_ids.shape[0])

        with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_bf16):
            logits = model(input_ids, n_loops=n_loops)
            loss = F.cross_entropy(
                logits.reshape(-1, cfg.vocab_size),
                labels.reshape(-1),
            )
        # Scale so the accumulated gradient is the mean over the window.
        (loss / args.grad_accum).backward()

        running_loss += loss.item()
        running_count += 1
        tokens_seen += input_ids.numel()
        micro_step += 1
        if micro_step % args.grad_accum != 0:
            continue

        # ── Optimizer update ──────────────────────────────────────
        grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0))
        optimizer.step()
        optimizer.zero_grad()
        step += 1

        # ── Logging ───────────────────────────────────────────────
        if step % args.log_interval == 0:
            avg_loss = running_loss / running_count
            last_avg_loss = avg_loss
            ppl = math.exp(min(avg_loss, 20))
            rho = model.get_spectral_radius()
            now = time.time()
            tok_per_sec = (tokens_seen - log_tokens) / max(now - log_t0, 1e-6)
            elapsed = now - t0
            print(f"step {step:>6d}/{args.max_steps} | loss {avg_loss:.4f} | ppl {ppl:.1f} | "
                  f"lr {lr:.2e} | gnorm {grad_norm:.3f} | ρ(A) {rho:.4f} | depth {n_loops} | "
                  f"tok/s {tok_per_sec:,.0f} | {elapsed:.0f}s")
            if trackio_space:
                trackio.log({
                    "train/loss": avg_loss,
                    "train/perplexity": ppl,
                    "train/lr": lr,
                    "train/grad_norm": grad_norm,
                    "train/spectral_radius": rho,
                    "train/loop_depth": n_loops,
                    "train/tokens_seen": tokens_seen,
                    "train/tok_per_sec": tok_per_sec,
                })
            running_loss = 0.0
            running_count = 0
            log_tokens = tokens_seen
            log_t0 = time.time()

        # ── Checkpoint ────────────────────────────────────────────
        if step % args.save_interval == 0:
            ckpt_path = f"ultron_ckpt_step{step}.pt"
            torch.save({
                "step": step,
                "tokens_seen": tokens_seen,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": asdict(cfg),
                "loss": last_avg_loss,
            }, ckpt_path)
            print(f"  Saved checkpoint: {ckpt_path}")
            if api is not None:
                err = _push_file(api, args.hub_model_id, ckpt_path, f"checkpoints/{ckpt_path}")
                if err is None:
                    print(f"  Pushed to {args.hub_model_id}")
                    # Reclaim disk space only once the Hub copy exists.
                    os.remove(ckpt_path)
                else:
                    print(f"  Hub push failed: {err} (kept local {ckpt_path})")

    # ── Final save ────────────────────────────────────────────────
    elapsed = time.time() - t0
    # Average over micro-batches since the last log line; when the run ended
    # exactly on a log boundary there are none, so report the last logged value.
    final_loss = running_loss / running_count if running_count else last_avg_loss
    print(f"\nTraining complete! {step} steps in {elapsed:.0f}s ({elapsed/3600:.1f}h)")
    print(f"Final loss: {final_loss:.4f}")
    print(f"Final ρ(A): {model.get_spectral_radius():.6f}")
    print(f"Tokens seen: {tokens_seen:,}")

    final_path = "ultron_final.pt"
    torch.save({
        "step": step,
        "tokens_seen": tokens_seen,
        "model_state_dict": model.state_dict(),
        "config": asdict(cfg),
    }, final_path)
    if api is not None:
        config_path = "config.json"
        with open(config_path, "w") as f:
            json.dump(asdict(cfg), f, indent=2, default=str)
        err = _push_file(api, args.hub_model_id, final_path, "ultron_final.pt")
        if err is None:
            err = _push_file(api, args.hub_model_id, config_path, "config.json")
        if err is None:
            print(f"Final model pushed to {args.hub_model_id}")
        else:
            print(f"Final push failed: {err}")

    if trackio_space:
        trackio.finish()
    print("Done!")
# ===========================================================================
# CLI
# ===========================================================================
def _positive_int(text):
    """``argparse`` type converter: parse ``text`` as an int and require >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {text!r}")
    return value


def main():
    """Parse command-line arguments and launch ``train``.

    ``--experiment``, ``--hub_model_id`` and ``--trackio_space`` fall back to
    the ``EXPERIMENT``, ``HUB_MODEL_ID`` and ``TRACKIO_SPACE_ID`` environment
    variables. Sizes and intervals that the training loop divides by or takes
    a modulo of must be positive; argparse rejects other values (exit code 2)
    before any training starts.
    """
    parser = argparse.ArgumentParser(description="Ultron Pretraining")
    parser.add_argument("--experiment", type=str,
                        default=os.environ.get("EXPERIMENT", "baseline"),
                        choices=["baseline", "moe"])
    parser.add_argument("--dataset_config", type=str, default="sample-10BT")
    parser.add_argument("--seq_len", type=_positive_int, default=1024)
    parser.add_argument("--batch_size", type=_positive_int, default=8)
    parser.add_argument("--grad_accum", type=_positive_int, default=8)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--min_lr", type=float, default=3e-5)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--log_interval", type=_positive_int, default=10)
    parser.add_argument("--save_interval", type=_positive_int, default=2000)
    parser.add_argument("--hub_model_id", type=str,
                        default=os.environ.get("HUB_MODEL_ID", None))
    parser.add_argument("--trackio_space", type=str,
                        default=os.environ.get("TRACKIO_SPACE_ID", None))
    args = parser.parse_args()
    # argparse checks ``choices`` only for values given on the command line,
    # not for defaults, so validate the EXPERIMENT env-var fallback here.
    if args.experiment not in ("baseline", "moe"):
        parser.error(f"invalid EXPERIMENT value {args.experiment!r} (choose 'baseline' or 'moe')")
    train(args)


if __name__ == "__main__":
    main()