#!/usr/bin/env python3
"""
train_ar.py — Autoregressive (GPT-2 style) baseline training.

Matches scripts/train.py in data pipeline, optimizer, and schedule; the only
differences are:
  - Model: src.models.ar_model.ARModel (standard decoder, causal self-attn)
  - Loss: shifted next-token cross-entropy, pad-masked by attention_mask
  - No AncestorTable, no NoisyStateBuilder, no t-weighting

Usage:
    python scripts/train_ar.py --config configs/ar_owt.yaml
    torchrun --nproc_per_node=8 scripts/train_ar.py --config configs/ar_owt.yaml
    torchrun --nproc_per_node=8 scripts/train_ar.py \\
        --config configs/ar_owt.yaml \\
        --resume outputs/ar_baseline/latest.pt
"""

import argparse
import math
import os
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]  # sad/

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import yaml

# Make the project's `src` package importable when run as a script.
sys.path.insert(0, str(ROOT))
from src.utils import set_seed, count_parameters
from src.models.ar_model import ARModel
from src.data import build_debug_dataloader, build_owt_dataloader

try:
    from tqdm import tqdm
    _has_tqdm = True
except ImportError:
    _has_tqdm = False


def _unwrap(model):
    """Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to ARModel."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


def parse_args():
    """Parse CLI flags; --num_steps / --batch_size override the YAML config."""
    p = argparse.ArgumentParser()
    p.add_argument("--config", default="configs/ar_owt.yaml")
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--num_steps", type=int, default=None)
    p.add_argument("--batch_size", type=int, default=None)
    p.add_argument("--local_rank", type=int, default=0)
    return p.parse_args()


def load_config(path: str) -> dict:
    """Load the YAML training config at *path* with the safe loader."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_tokenizer(config: dict):
    """Identical to train.py so the two runs consume the exact same token stream.

    For the "debug" dataset a minimal mock tokenizer is returned. Otherwise the
    local GPT-2 tokenizer under ROOT/tokenizers/gpt2 is loaded, missing
    eos/bos/pad tokens are filled in, and config["model"]["vocab_size"] is
    overwritten in place with len(tokenizer).
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    vocab_size = config["model"]["vocab_size"]

    if dataset == "debug":
        class MockTokenizer:
            def __init__(self, vocab_size):
                self.vocab_size = vocab_size
                self.pad_token_id = 0
                self.eos_token_id = 0
                self.bos_token_id = 0
                self.mask_token_id = vocab_size - 1
                self.model_max_length = config["model"]["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        return MockTokenizer(vocab_size)

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # AR baseline does not need [MASK], but the shared OWT dataloader builds
    # the same token stream regardless — so no special handling required.
    config["model"]["vocab_size"] = len(tok)
    return tok


def build_dataloaders(config: dict, tokenizer):
    """Return (train_loader, val_loader) for the dataset named in config["data"].

    Raises:
        ValueError: if config["data"]["dataset"] is not "debug" or "openwebtext".
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        train_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=512,
            mask_token_id=getattr(tokenizer, "mask_token_id", 0) or 0,
        )
        val_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=64,
            mask_token_id=getattr(tokenizer, "mask_token_id", 0) or 0,
        )
    elif dataset == "openwebtext":
        mode = data_cfg.get("mode", "subsample")
        # The last 100k training examples are held out as the validation split.
        train_loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=data_cfg.get("num_workers", 4),
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_train_samples", None),
            mode=mode,
        )
        val_loader = build_owt_dataloader(
            tokenizer,
            split="train[-100000:]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=2,
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_val_samples", 100000),
            mode=mode,
            shard_across_ranks=False,
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    return train_loader, val_loader


def build_optimizer(config: dict, model: nn.Module):
    """Build AdamW over all model parameters using config["training"] hyperparameters."""
    train_cfg = config["training"]
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    params = list(model.parameters())
    # Fused AdamW kernels only accept accelerator tensors on older PyTorch
    # releases; fall back to the default implementation for CPU runs (e.g.
    # the debug dataset) instead of crashing at construction time.
    use_fused = bool(params) and all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=use_fused,
    )


def get_lr(step: int, config: dict) -> float:
    """Linear warmup + cosine decay, identical to train.py.

    Ramps linearly from 0 to lr over warmup_steps, then follows a half-cosine
    from lr down to lr_min at num_steps.
    """
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_min = train_cfg.get("lr_min", train_cfg["lr"] * 0.1)
    lr_max = train_cfg["lr"]
    if step < warmup:
        return lr_max * step / max(warmup, 1)
    progress = (step - warmup) / max(num_steps - warmup, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


def ar_step(batch: dict, model, dtype) -> tuple:
    """
    Shifted next-token CE:
        inputs  = input_ids[:, :-1]
        targets = input_ids[:, 1:]
    loss is averaged over non-pad target positions.

    Returns:
        (loss, metrics) where metrics holds detached "loss_ce", "loss_total",
        "ppl" (exp of the batch loss) and "valid_tokens".
    """
    input_ids = batch["input_ids"]            # [B, S]
    attention_mask = batch["attention_mask"]  # [B, S]
    device = input_ids.device
    autocast_device = "cuda" if device.type == "cuda" else "cpu"

    with torch.autocast(device_type=autocast_device, dtype=dtype):
        logits = model(input_ids=input_ids)  # [B, S, V]

    B, S, V = logits.shape
    # Shift-by-one: position i predicts token i+1.
    logits_shift = logits[:, :-1, :].contiguous()  # [B, S-1, V]
    targets = input_ids[:, 1:].contiguous()        # [B, S-1]
    target_mask = attention_mask[:, 1:].float()    # [B, S-1]

    # fp32 CE for bf16 safety (same rationale as SADLoss).
    ce = F.cross_entropy(
        logits_shift.reshape(-1, V).float(),
        targets.reshape(-1),
        reduction="none",
    ).reshape(B, S - 1)

    total_valid = target_mask.sum().clamp(min=1)
    loss = (ce * target_mask).sum() / total_valid

    metrics = {
        "loss_ce": loss.detach(),
        "loss_total": loss.detach(),
        "ppl": loss.detach().exp(),
        "valid_tokens": total_valid.detach(),
    }
    return loss, metrics


def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """Write ckpt_{step}.pt and refresh latest.pt in *save_dir*.

    latest.pt is written to a temporary file and atomically renamed, so an
    interruption mid-save never leaves a truncated resume target behind.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
    }
    torch.save(ckpt, save_dir / f"ckpt_{step}.pt")
    tmp_path = save_dir / "latest.pt.tmp"
    torch.save(ckpt, tmp_path)
    os.replace(tmp_path, save_dir / "latest.pt")
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")


@torch.no_grad()
def evaluate(model, dtype, val_loader, device, num_batches: int = 50) -> dict:
    """Run up to *num_batches* validation batches and return averaged metrics.

    "loss_ce"/"loss_total" are the token-weighted mean CE over all evaluated
    target tokens, and "ppl" is exp of that mean (per-token perplexity).
    Averaging per-batch exp(loss) instead would overstate perplexity. Other
    metrics (e.g. "valid_tokens") are plain per-batch means. Returns {} if
    the loader yields nothing. The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    totals: dict = {}
    ce_sum = 0.0
    token_sum = 0.0
    count = 0
    try:
        for batch in val_loader:
            if count >= num_batches:
                break
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            _, metrics = ar_step(batch, model, dtype)
            scalars = {
                k: (v.item() if hasattr(v, "item") else float(v))
                for k, v in metrics.items()
            }
            for k, val in scalars.items():
                totals[k] = totals.get(k, 0.0) + val
            ce_sum += scalars["loss_ce"] * scalars["valid_tokens"]
            token_sum += scalars["valid_tokens"]
            count += 1
    finally:
        model.train(was_training)

    if count == 0:
        return {}
    result = {k: v / count for k, v in totals.items()}
    mean_ce = ce_sum / max(token_sum, 1.0)
    result["loss_ce"] = mean_ce
    result["loss_total"] = mean_ce
    # Guard math.exp against OverflowError on a diverged (huge but finite) loss.
    result["ppl"] = math.exp(mean_ce) if mean_ce < 700.0 else float("inf")
    return result


def fmt_metric(v) -> str:
    """Format a tensor or number with four decimals."""
    v = v.item() if hasattr(v, "item") else float(v)
    return f"{v:.4f}"


def _init_distributed(args):
    """Return (rank, local_rank, world_size, device), initializing NCCL when multi-process."""
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # torchrun exports the global RANK; on a single node it equals LOCAL_RANK.
    rank = int(os.environ.get("RANK", local_rank))

    if world_size > 1:
        dist.init_process_group("nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return rank, local_rank, world_size, device


def _build_model(config: dict, device, local_rank: int, world_size: int, is_main: bool):
    """Construct ARModel on *device*, wrapping it in DDP and torch.compile as configured."""
    model_cfg = config["model"]
    train_cfg = config["training"]
    model = ARModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        max_seq_len=model_cfg["max_seq_len"],
        dropout=model_cfg.get("dropout", 0.0),
    ).to(device)

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = train_cfg.get("compile", "default")
    # PyYAML follows YAML 1.1, so a bare `compile: off` / `no` loads as False
    # and `on` / `yes` as True; accept those alongside the string "off".
    if compile_mode is True:
        compile_mode = "default"
    if compile_mode is not False and compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)
    return model


def main():
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    rank, local_rank, world_size, device = _init_distributed(args)
    # Elect the *global* rank 0 as the only process that logs, evaluates and
    # writes checkpoints; keying on LOCAL_RANK would elect one per node.
    is_main = rank == 0
    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Offset by global rank so seeds stay unique across nodes as well.
    set_seed(train_cfg.get("seed", 42) + rank)
    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]
    # NOTE(review): fp16 is trained without a GradScaler; bf16 (default) needs none.

    # The tokenizer must be built first: it may rewrite config["model"]["vocab_size"].
    tokenizer = build_tokenizer(config)
    model = _build_model(config, device, local_rank, world_size, is_main)
    optimizer = build_optimizer(config, model)
    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    # ── Resume ────────────────────────────────────────────────────────────────
    start_step = 0
    if args.resume:
        # torch.load unpickles the file — only resume from trusted checkpoints.
        ckpt = torch.load(args.resume, map_location=device)
        _unwrap(model).load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # Checkpoints are written after the optimizer step of "step" completes.
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/ar_baseline"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "sad_ar_baseline"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="AR baseline training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Restart the loader once exhausted (epoch boundary).
        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = ar_step(batch, model, dtype)

        # Symmetric NaN-skip across DDP ranks (same pattern as train.py): if any
        # rank saw a non-finite loss, every rank skips backward together so the
        # gradient all-reduce never runs on only a subset of ranks.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if world_size > 1:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        optimizer.step()
        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                ce=fmt_metric(metrics["loss_ce"]),
                ppl=fmt_metric(metrics["ppl"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"ce={fmt_metric(metrics['loss_ce'])} | "
                f"ppl={fmt_metric(metrics['ppl'])} | "
                f"lr={lr:.2e} | "
                f"t={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step,
                    "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()}
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # Evaluate the bare module: only rank 0 reaches this point, and a
            # forward through the DDP wrapper may issue collectives (e.g. the
            # buffer broadcast, if ARModel registers buffers) that the other
            # ranks never join, which would hang the job.
            val_metrics = evaluate(_unwrap(model), dtype, val_loader, device)
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_ce", "loss_total", "ppl")
            ))
            if use_wandb:
                wandb.log({"step": step, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            save_checkpoint(step, _unwrap(model), optimizer, config, save_dir, last_metrics)

    if pbar is not None:
        pbar.close()

    if is_main:
        save_checkpoint(num_steps, _unwrap(model), optimizer, config, save_dir, last_metrics)
        print("Training complete.")

    if use_wandb:
        wandb.finish()

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()