| |
| """ |
| train_ar.py — Autoregressive (GPT-2 style) baseline training. |
| |
| Matches scripts/train.py in data pipeline, optimizer, and schedule; the only |
| differences are: |
| - Model: src.models.ar_model.ARModel (standard decoder, causal self-attn) |
| - Loss: shifted next-token cross-entropy, pad-masked by attention_mask |
| - No AncestorTable, no NoisyStateBuilder, no t-weighting |
| |
| Usage: |
| python scripts/train_ar.py --config configs/ar_owt.yaml |
| torchrun --nproc_per_node=8 scripts/train_ar.py --config configs/ar_owt.yaml |
| torchrun --nproc_per_node=8 scripts/train_ar.py \\ |
| --config configs/ar_owt.yaml \\ |
| --resume outputs/ar_baseline/latest.pt |
| """ |
|
|
| import sys |
| import os |
| import argparse |
| import math |
| import time |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import yaml |
|
|
| sys.path.insert(0, str(ROOT)) |
|
|
| from src.utils import set_seed, count_parameters |
| from src.models.ar_model import ARModel |
| from src.data import build_debug_dataloader, build_owt_dataloader |
|
|
| try: |
| from tqdm import tqdm |
| _has_tqdm = True |
| except ImportError: |
| _has_tqdm = False |
|
|
|
|
| def _unwrap(model): |
| """Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to ARModel.""" |
| while True: |
| if hasattr(model, "_orig_mod"): |
| model = model._orig_mod |
| elif hasattr(model, "module"): |
| model = model.module |
| else: |
| return model |
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for AR-baseline training.

    Args:
        argv: argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``. Exposed so the parser can be driven
            programmatically (e.g. from tests).

    ``--local_rank`` is also accepted as ``--local-rank``: the legacy
    ``torch.distributed.launch`` launcher passes the hyphenated spelling in
    recent PyTorch versions, which an underscore-only flag rejects as an
    unrecognized argument. ``main`` prefers the ``LOCAL_RANK`` environment
    variable (set by ``torchrun``) over this flag when it is present.

    Returns:
        argparse.Namespace with ``config``, ``resume``, ``num_steps``,
        ``batch_size`` and ``local_rank``.
    """
    p = argparse.ArgumentParser(description="Autoregressive (GPT-2 style) baseline training.")
    p.add_argument("--config", default="configs/ar_owt.yaml",
                   help="YAML training config")
    p.add_argument("--resume", default=None, type=str,
                   help="checkpoint to resume from (e.g. <save_dir>/latest.pt)")
    p.add_argument("--num_steps", type=int, default=None,
                   help="override training.num_steps")
    p.add_argument("--batch_size", type=int, default=None,
                   help="override training.batch_size")
    p.add_argument("--local_rank", "--local-rank", type=int, default=0,
                   help="fallback local rank when LOCAL_RANK is not set")
    return p.parse_args(argv)
|
|
|
|
def load_config(path: str) -> dict:
    """Load the YAML training config at *path*.

    The file is read as UTF-8 so parsing does not depend on the platform's
    locale encoding, and ``yaml.safe_load`` is used so no arbitrary Python
    objects can be constructed from the file.

    Raises:
        ValueError: if the document is empty or its top level is not a
            mapping. ``main`` indexes the result like a dict
            (``config["training"]``), so anything else would otherwise fail
            later with a less helpful ``TypeError``.
    """
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {path!r} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
def build_tokenizer(config: dict):
    """Return the tokenizer for the configured dataset (same logic as train.py).

    * ``data.dataset == "debug"`` (also the default when unset): a lightweight
      stand-in exposing ``vocab_size``, the special-token ids and
      ``model_max_length``; ``config`` is not modified.
    * any other dataset: the tokenizer stored at ``<ROOT>/tokenizers/gpt2``,
      loaded offline. A missing EOS token is added, and BOS / PAD fall back to
      EOS. ``config["model"]["vocab_size"]`` is then overwritten with
      ``len(tokenizer)`` so the model built from that config matches it.

    Keeping this identical to train.py means both runs consume the exact same
    token stream.
    """
    model_cfg = config["model"]
    vocab_size = model_cfg["vocab_size"]
    dataset = config.get("data", {}).get("dataset", "debug")

    if dataset != "debug":
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(
            ROOT / "tokenizers" / "gpt2",
            local_files_only=True,
        )
        if tok.eos_token is None:
            tok.add_special_tokens({"eos_token": "<|endoftext|>"})
        # BOS and PAD reuse EOS whenever the tokenizer leaves them undefined.
        if tok.bos_token is None:
            tok.bos_token = tok.eos_token
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        # Adding a special token can grow the vocabulary; keep the config in sync.
        model_cfg["vocab_size"] = len(tok)
        return tok

    class MockTokenizer:
        """Minimal tokenizer stand-in for the synthetic debug data."""

        def __init__(self, vocab_size):
            self.vocab_size = vocab_size
            # PAD, EOS and BOS share id 0; the last id is reserved for MASK.
            self.pad_token_id = self.eos_token_id = self.bos_token_id = 0
            self.mask_token_id = vocab_size - 1
            self.model_max_length = model_cfg["max_seq_len"]

        def __len__(self):
            return self.vocab_size

    return MockTokenizer(vocab_size)
|
|
|
|
def build_dataloaders(config: dict, tokenizer):
    """Build ``(train_loader, val_loader)`` for ``config["data"]["dataset"]``.

    Supported datasets:
      * ``"debug"`` (default): synthetic loaders from ``build_debug_dataloader``
        with 512 train / 64 val samples.
      * ``"openwebtext"``: the ``train`` split, with its last ``val_holdout``
        rows held out for validation. The val loader is built with
        ``shard_across_ranks=False`` (presumably so rank-0 evaluation sees the
        whole held-out set; confirm against ``src.data``).

    Optional ``data`` keys (defaults reproduce the previously hard-coded values):
      ``seq_len`` (512), ``num_workers`` (4), ``val_num_workers`` (2),
      ``val_holdout`` (100000), ``max_val_samples`` (= ``val_holdout``),
      ``max_train_samples`` (None), ``cache_dir`` (None), ``mode`` ("subsample").

    Raises:
        ValueError: for an unknown dataset name or a non-positive ``val_holdout``.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        vocab_size = config["model"]["vocab_size"]
        # A missing or None/0 mask_token_id falls back to 0.
        mask_token_id = getattr(tokenizer, "mask_token_id", 0) or 0
        train_loader = build_debug_dataloader(
            vocab_size=vocab_size,
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=512,
            mask_token_id=mask_token_id,
        )
        val_loader = build_debug_dataloader(
            vocab_size=vocab_size,
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=64,
            mask_token_id=mask_token_id,
        )
        return train_loader, val_loader

    if dataset == "openwebtext":
        val_holdout = int(data_cfg.get("val_holdout", 100000))
        if val_holdout <= 0:
            raise ValueError(f"data.val_holdout must be positive, got {val_holdout}")
        mode = data_cfg.get("mode", "subsample")
        cache_dir = data_cfg.get("cache_dir", None)
        train_loader = build_owt_dataloader(
            tokenizer,
            split=f"train[:-{val_holdout}]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=data_cfg.get("num_workers", 4),
            cache_dir=cache_dir,
            max_samples=data_cfg.get("max_train_samples", None),
            mode=mode,
        )
        val_loader = build_owt_dataloader(
            tokenizer,
            split=f"train[-{val_holdout}:]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=data_cfg.get("val_num_workers", 2),
            cache_dir=cache_dir,
            max_samples=data_cfg.get("max_val_samples", val_holdout),
            mode=mode,
            shard_across_ranks=False,
        )
        return train_loader, val_loader

    raise ValueError(f"Unknown dataset: {dataset}")
|
|
|
|
def build_optimizer(config: dict, model: nn.Module):
    """Create the AdamW optimizer over every parameter of *model*.

    Hyper-parameters come from ``config["training"]``: ``lr`` (required),
    ``weight_decay`` (default 0.02), ``adam_betas`` (default (0.9, 0.99)) and
    ``adam_eps`` (default 1e-9). Weight decay is applied uniformly to all
    parameters, matching train.py.

    ``fused`` defaults to True only when every parameter is on a CUDA device.
    Older PyTorch releases (before roughly 2.4) only provide the fused AdamW
    kernel for CUDA and raise at construction when it is requested for CPU
    parameters, which broke this script's CPU fallback. Set
    ``training.fused`` to force the choice either way.
    """
    train_cfg = config["training"]
    params = list(model.parameters())
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    fused = train_cfg.get("fused")
    if fused is None:
        fused = bool(params) and all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=bool(fused),
    )
|
|
|
|
def get_lr(step: int, config: dict) -> float:
    """Learning rate at *step*: linear warmup, then cosine decay to ``lr_min``.

    Identical to train.py for ``0 <= step <= num_steps``:
      * ``step < warmup``: ``lr * step / warmup`` (starts at 0).
      * otherwise: cosine from ``lr`` down to ``lr_min`` across the remaining
        ``num_steps - warmup`` steps.

    Steps beyond ``num_steps`` are held at ``lr_min``. Without the clamp the
    cosine would wrap around and the learning rate would climb back towards
    ``lr``.

    ``config["training"]`` keys: ``lr`` and ``num_steps`` (required),
    ``warmup_steps`` (default ``min(2000, num_steps // 100)``) and ``lr_min``
    (default ``0.1 * lr``).
    """
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    lr_max = train_cfg["lr"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_min = train_cfg.get("lr_min", lr_max * 0.1)
    if step < warmup:
        return lr_max * step / max(warmup, 1)
    progress = min((step - warmup) / max(num_steps - warmup, 1), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
|
|
|
|
def ar_step(batch: dict, model, dtype) -> tuple:
    """Compute the shifted next-token cross-entropy for one batch.

    Position ``i`` predicts token ``i + 1``: logits ``[:, :-1]`` are scored
    against ``input_ids[:, 1:]``. Only targets whose ``attention_mask`` entry
    is non-zero count toward the mean; the denominator is clamped to at least
    1, so a fully padded batch gives loss 0 rather than NaN.

    Args:
        batch: dict holding ``input_ids`` and ``attention_mask`` tensors,
            sliced below as ``[batch, seq]``.
        model: called as ``model(input_ids=...)``; its output is unpacked as
            ``(B, S, V)`` logits.
        dtype: autocast dtype for the forward pass.

    Returns:
        ``(loss, metrics)``: the differentiable scalar loss, plus detached
        ``loss_ce`` / ``loss_total`` / ``ppl`` / ``valid_tokens`` tensors.
    """
    ids = batch["input_ids"]
    mask = batch["attention_mask"]

    amp_device = "cuda" if ids.device.type == "cuda" else "cpu"
    with torch.autocast(device_type=amp_device, dtype=dtype):
        logits = model(input_ids=ids)

    n_batch, n_pos, n_vocab = logits.shape

    # Drop the final prediction (it has no next token) and the first token
    # (nothing predicts it).
    pred = logits[:, :-1, :].contiguous()
    gold = ids[:, 1:].contiguous()
    keep = mask[:, 1:].float()

    # Cross-entropy is evaluated in fp32 (.float()) whatever the autocast dtype.
    per_token = F.cross_entropy(
        pred.reshape(-1, n_vocab).float(),
        gold.reshape(-1),
        reduction="none",
    ).reshape(n_batch, n_pos - 1)

    n_valid = keep.sum().clamp(min=1)
    loss = (per_token * keep).sum() / n_valid

    metrics = {
        "loss_ce": loss.detach(),
        "loss_total": loss.detach(),
        "ppl": loss.detach().exp(),
        "valid_tokens": n_valid.detach(),
    }
    return loss, metrics
|
|
|
|
def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """Write ``ckpt_{step}.pt`` and refresh ``latest.pt`` under *save_dir*.

    The saved dict holds ``step``, the ``model`` and ``optimizer`` state dicts,
    ``config``, and ``metrics`` (values with ``.item()`` are converted to
    Python scalars). *save_dir* is created if needed.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace``. An interrupted save therefore never leaves a
    truncated ``latest.pt``, which is the file ``--resume`` is pointed at in
    the usage examples.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
    }

    def _atomic_save(path: Path) -> None:
        """Serialize ``ckpt`` to *path* via a temp file + rename."""
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, path)

    _atomic_save(save_dir / f"ckpt_{step}.pt")
    _atomic_save(save_dir / "latest.pt")
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")
|
|
|
|
@torch.no_grad()
def evaluate(model, dtype, val_loader, device, num_batches: int = 50) -> dict:
    """Aggregate ``ar_step`` metrics over at most *num_batches* validation batches.

    ``loss_ce`` and ``loss_total`` are token-weighted: each batch's mean CE is
    weighted by its ``valid_tokens``. ``ppl`` is ``exp`` of that weighted
    mean, i.e. the perplexity of the evaluated tokens as a whole. A plain mean
    of per-batch ``exp(loss)`` overstates perplexity (Jensen's inequality) and
    over-weights batches with few valid tokens. Any other metric (e.g.
    ``valid_tokens``) is a plain per-batch mean, as before. Note: ``ar_step``
    clamps ``valid_tokens`` to >= 1, so a fully padded batch contributes
    weight 1 with loss 0.

    The model is switched to eval mode for the duration and restored to its
    previous mode afterwards, even if a batch raises. Returns ``{}`` when the
    loader yields no batches.
    """
    was_training = model.training
    model.eval()
    sums: dict = {}
    weighted_ce = 0.0
    total_tokens = 0.0
    count = 0
    try:
        # zip() pulls from range() first, so no extra batch is fetched after
        # num_batches have been consumed.
        for _, batch in zip(range(num_batches), val_loader):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            _, metrics = ar_step(batch, model, dtype)
            values = {
                k: v.item() if hasattr(v, "item") else float(v)
                for k, v in metrics.items()
            }
            for k, v in values.items():
                sums[k] = sums.get(k, 0.0) + v
            n_tokens = values.get("valid_tokens", 1.0)
            weighted_ce += values["loss_ce"] * n_tokens
            total_tokens += n_tokens
            count += 1
    finally:
        model.train(was_training)

    results = {k: v / max(count, 1) for k, v in sums.items()}
    if total_tokens > 0:
        mean_ce = weighted_ce / total_tokens
        try:
            ppl = math.exp(mean_ce)
        except OverflowError:
            ppl = float("inf")
        # ar_step reports loss_total equal to loss_ce (CE is the only term).
        results["loss_ce"] = mean_ce
        results["loss_total"] = mean_ce
        results["ppl"] = ppl
    return results
|
|
|
|
def fmt_metric(v) -> str:
    """Render a scalar (Python number or anything with ``.item()``) to 4 decimals."""
    scalar = float(v) if not hasattr(v, "item") else v.item()
    return "{:.4f}".format(scalar)
|
|
|
|
def main():
    """Run AR-baseline training end to end.

    Parses CLI overrides, sets up the device (an NCCL process group when
    ``WORLD_SIZE > 1``), builds tokenizer / model / optimizer / data,
    optionally resumes from a checkpoint, then trains for
    ``training.num_steps`` steps with periodic logging, evaluation and
    checkpointing.

    Every rank trains; only global rank 0 (``RANK``) prints, evaluates,
    writes the config/checkpoints and talks to wandb.
    """
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # Main-process duties key off the *global* rank. LOCAL_RANK is 0 on every
    # node of a multi-node job, which would make each node log and write
    # checkpoints into the same save_dir. Single-process runs (RANK unset)
    # fall back to local_rank.
    rank = int(os.environ.get("RANK", local_rank))
    is_main = rank == 0

    if world_size > 1:
        import torch.distributed as dist
        device = torch.device(f"cuda:{local_rank}")
        # Bind this process to its GPU before creating the NCCL group so that
        # collectives issued during initialization use the right device.
        torch.cuda.set_device(device)
        dist.init_process_group("nccl")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Distinct seed per global rank (equal to local_rank on a single node).
    set_seed(train_cfg.get("seed", 42) + rank)

    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)

    model_cfg = config["model"]
    model = ARModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        max_seq_len=model_cfg["max_seq_len"],
        dropout=model_cfg.get("dropout", 0.0),
    ).to(device)

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = train_cfg.get("compile", "default")
    # PyYAML parses unquoted on/off/true/false/yes/no as booleans, so
    # `compile: off` arrives as False (which torch.compile rejects as a mode).
    if compile_mode is True:
        compile_mode = "default"
    if compile_mode is not False and compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model)

    # fp16 autocast needs dynamic loss scaling to keep small gradients from
    # underflowing. For bf16/fp32 (or CPU) the scaler is disabled, and its
    # scale / unscale_ / step / update calls reduce to a plain
    # loss.backward() + optimizer.step(). Scaler state is not checkpointed;
    # on resume the scale simply re-calibrates.
    use_grad_scaler = dtype == torch.float16 and device.type == "cuda"
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_grad_scaler)
    except (AttributeError, TypeError):  # older PyTorch: no torch.amp.GradScaler
        scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaler)

    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        # Checkpoints store the bare ARModel state dict (see save calls below).
        raw_model = _unwrap(model)
        raw_model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/ar_baseline"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "sad_ar_baseline"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)

    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="AR baseline training",
        )
    else:
        pbar = None

    t0 = time.time()

    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Cycle the training loader indefinitely.
        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = ar_step(batch, model, dtype)

        # All ranks agree on whether to skip: if any rank saw a non-finite
        # loss, every rank skips the update together.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if world_size > 1:
            import torch.distributed as dist
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    import wandb
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        scaler.scale(loss).backward()

        if grad_clip > 0:
            # Clip the true (unscaled) gradients; no-op when the scaler is disabled.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                ce=fmt_metric(metrics["loss_ce"]),
                ppl=fmt_metric(metrics["ppl"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"ce={fmt_metric(metrics['loss_ce'])} | "
                f"ppl={fmt_metric(metrics['ppl'])} | "
                f"lr={lr:.2e} | "
                f"t={elapsed:.1f}s"
            )
            if use_wandb:
                import wandb
                wandb.log({
                    "step": step, "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v
                       for k, v in metrics.items()}
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # Evaluate the bare ARModel. Only rank 0 runs this, and a DDP
            # forward may broadcast module buffers (if ARModel registers any)
            # to ranks that are not evaluating, mismatching collectives. Going
            # through torch.compile could also trigger a recompile for eval mode.
            val_metrics = evaluate(_unwrap(model), dtype, val_loader, device)
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_ce", "loss_total", "ppl")
            ))
            if use_wandb:
                import wandb
                wandb.log({"step": step,
                           **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            raw_model = _unwrap(model)
            save_checkpoint(step, raw_model, optimizer, config, save_dir, last_metrics)

    if is_main:
        raw_model = _unwrap(model)
        save_checkpoint(num_steps, raw_model, optimizer, config, save_dir, last_metrics)
        print("Training complete.")
        if pbar is not None:
            pbar.close()
        if use_wandb:
            import wandb
            wandb.finish()

    if world_size > 1:
        import torch.distributed as dist
        dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Train only when executed as a script (python / torchrun), never on import.
    main()
|
|