sad / scripts /train_ar.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
16 kB
#!/usr/bin/env python3
"""
train_ar.py — Autoregressive (GPT-2 style) baseline training.
Matches scripts/train.py in data pipeline, optimizer, and schedule; the only
differences are:
- Model: src.models.ar_model.ARModel (standard decoder, causal self-attn)
- Loss: shifted next-token cross-entropy, pad-masked by attention_mask
- No AncestorTable, no NoisyStateBuilder, no t-weighting
Usage:
python scripts/train_ar.py --config configs/ar_owt.yaml
torchrun --nproc_per_node=8 scripts/train_ar.py --config configs/ar_owt.yaml
torchrun --nproc_per_node=8 scripts/train_ar.py \\
--config configs/ar_owt.yaml \\
--resume outputs/ar_baseline/latest.pt
"""
import sys
import os
import argparse
import math
import time
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
sys.path.insert(0, str(ROOT))
from src.utils import set_seed, count_parameters
from src.models.ar_model import ARModel
from src.data import build_debug_dataloader, build_owt_dataloader
try:
from tqdm import tqdm
_has_tqdm = True
except ImportError:
_has_tqdm = False
def _unwrap(model):
"""Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to ARModel."""
while True:
if hasattr(model, "_orig_mod"):
model = model._orig_mod
elif hasattr(model, "module"):
model = model.module
else:
return model
def parse_args():
    """Parse the command-line flags of the training script.

    Flags: ``--config`` (YAML path), ``--resume`` (checkpoint path),
    ``--num_steps`` / ``--batch_size`` (optional integer overrides, ``None``
    when absent) and ``--local_rank`` (integer, default 0).
    """
    parser = argparse.ArgumentParser()
    option_table = (
        ("--config", {"default": "configs/ar_owt.yaml"}),
        ("--resume", {"default": None, "type": str}),
        ("--num_steps", {"type": int, "default": None}),
        ("--batch_size", {"type": int, "default": None}),
        ("--local_rank", {"type": int, "default": 0}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def load_config(path: str) -> dict:
    """Load the YAML configuration file at ``path``.

    The file is read as UTF-8 (independent of the platform locale) and parsed
    with ``yaml.safe_load``.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: if the document is empty or its top level is not a
            mapping, so the problem is reported here rather than as an
            opaque ``TypeError`` on the first ``config[...]`` lookup.
    """
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {path!r} must contain a top-level mapping, "
            f"got {type(config).__name__}"
        )
    return config
def build_tokenizer(config: dict):
    """Build the tokenizer selected by ``config["data"]["dataset"]``.

    For the ``"debug"`` dataset (also the default when no dataset is set) a
    small in-memory stand-in is returned. Any other dataset loads the GPT-2
    tokenizer from ``ROOT/tokenizers/gpt2`` (local files only), fills in
    missing eos/bos/pad tokens, and overwrites ``config["model"]["vocab_size"]``
    with the tokenizer length.
    """
    dataset_name = config.get("data", {}).get("dataset", "debug")
    # Read up front so a missing key fails the same way for every dataset.
    n_vocab = config["model"]["vocab_size"]

    if dataset_name != "debug":
        from transformers import AutoTokenizer

        tokenizer_dir = ROOT / "tokenizers" / "gpt2"
        hf_tok = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
        if hf_tok.eos_token is None:
            hf_tok.add_special_tokens({"eos_token": "<|endoftext|>"})
        # bos and pad fall back to the eos token when the tokenizer lacks them.
        if hf_tok.bos_token is None:
            hf_tok.bos_token = hf_tok.eos_token
        if hf_tok.pad_token is None:
            hf_tok.pad_token = hf_tok.eos_token
        # No [MASK] token is added for the AR baseline; the model's vocabulary
        # size is simply synced to the tokenizer.
        config["model"]["vocab_size"] = len(hf_tok)
        return hf_tok

    class MockTokenizer:
        """Minimal tokenizer stand-in exposing only ids and sizes."""

        def __init__(self, vocab_size):
            self.vocab_size = vocab_size
            # pad, eos and bos all share id 0; the mask id is the last slot.
            self.pad_token_id = self.eos_token_id = self.bos_token_id = 0
            self.mask_token_id = vocab_size - 1
            self.model_max_length = config["model"]["max_seq_len"]

        def __len__(self):
            return self.vocab_size

    return MockTokenizer(n_vocab)
def build_dataloaders(config: dict, tokenizer):
    """Create the ``(train_loader, val_loader)`` pair for the configured dataset.

    ``"debug"`` uses ``build_debug_dataloader`` with 512 training and 64
    validation samples. ``"openwebtext"`` uses ``build_owt_dataloader`` with
    the last 100000 rows of the ``train`` split held out for validation; the
    validation loader is requested with ``shard_across_ranks=False``.

    Raises:
        ValueError: for any other dataset name.
    """
    data_cfg = config.get("data", {})
    dataset_name = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset_name == "debug":
        debug_kwargs = {
            "vocab_size": config["model"]["vocab_size"],
            "seq_len": seq_len,
            "batch_size": batch_size,
            # A missing or None mask id degrades to 0.
            "mask_token_id": getattr(tokenizer, "mask_token_id", 0) or 0,
        }
        train_loader = build_debug_dataloader(num_samples=512, **debug_kwargs)
        val_loader = build_debug_dataloader(num_samples=64, **debug_kwargs)
        return train_loader, val_loader

    if dataset_name == "openwebtext":
        owt_kwargs = {
            "seq_len": seq_len,
            "batch_size": batch_size,
            "cache_dir": data_cfg.get("cache_dir", None),
            "mode": data_cfg.get("mode", "subsample"),
        }
        train_loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            num_workers=data_cfg.get("num_workers", 4),
            max_samples=data_cfg.get("max_train_samples", None),
            **owt_kwargs,
        )
        val_loader = build_owt_dataloader(
            tokenizer,
            split="train[-100000:]",
            num_workers=2,
            max_samples=data_cfg.get("max_val_samples", 100000),
            shard_across_ranks=False,
            **owt_kwargs,
        )
        return train_loader, val_loader

    raise ValueError(f"Unknown dataset: {dataset_name}")
def build_optimizer(config: dict, model: nn.Module):
    """Create the AdamW optimizer over all parameters of ``model``.

    Hyper-parameters come from ``config["training"]``: ``lr`` (required),
    ``weight_decay`` (default 0.02), ``adam_betas`` (default ``(0.9, 0.99)``)
    and ``adam_eps`` (default 1e-9).

    The fused kernel is requested only when every parameter lives on a CUDA
    device. This script also supports a CPU fallback device, and requesting
    ``fused=True`` for CPU parameters is rejected by torch releases whose
    fused AdamW is CUDA-only.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    train_cfg = config["training"]
    params = list(model.parameters())
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    use_fused = bool(params) and all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=use_fused,
    )
def get_lr(step: int, config: dict) -> float:
    """Learning rate for ``step``: linear warmup followed by cosine decay.

    Reads ``num_steps`` and ``lr`` from ``config["training"]``. Optional keys:
    ``warmup_steps`` (default ``min(2000, num_steps // 100)``) and ``lr_min``
    (default ``0.1 * lr``). During warmup the rate rises linearly from 0 to
    ``lr``; afterwards it follows a half cosine from ``lr`` down to ``lr_min``
    at ``num_steps``.
    """
    schedule = config["training"]
    total_steps = schedule["num_steps"]
    peak = schedule["lr"]
    warmup_steps = schedule.get("warmup_steps", min(2000, total_steps // 100))
    floor = schedule.get("lr_min", peak * 0.1)

    if step >= warmup_steps:
        # Fraction of the post-warmup phase that has elapsed.
        frac = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * frac))
    return peak * step / max(warmup_steps, 1)
def ar_step(batch: dict, model, dtype) -> tuple:
    """Run one forward pass and compute the next-token cross-entropy loss.

    The model is called on the full ``batch["input_ids"]`` under autocast with
    ``dtype``. The logits at position ``i`` are scored against the token at
    position ``i + 1``, and the per-token losses are averaged over the
    positions where ``batch["attention_mask"][:, 1:]`` is non-zero.

    Returns:
        ``(loss, metrics)`` where ``metrics`` holds detached tensors under the
        keys ``loss_ce``, ``loss_total``, ``ppl`` (``exp`` of the loss) and
        ``valid_tokens``.
    """
    tokens = batch["input_ids"]        # [B, S]
    mask = batch["attention_mask"]     # [B, S]

    # Autocast only distinguishes CUDA from everything else here.
    amp_device = "cuda" if tokens.device.type == "cuda" else "cpu"
    with torch.autocast(device_type=amp_device, dtype=dtype):
        logits = model(input_ids=tokens)  # [B, S, V]

    vocab = logits.shape[-1]
    # Drop the last prediction and the first token so the two line up.
    pred = logits[:, :-1, :].contiguous()     # [B, S-1, V]
    gold = tokens[:, 1:].contiguous()         # [B, S-1]
    keep = mask[:, 1:].float()                # [B, S-1]

    # The cross-entropy is computed on float32 logits regardless of `dtype`.
    per_token = F.cross_entropy(
        pred.reshape(-1, vocab).float(),
        gold.reshape(-1),
        reduction="none",
    ).reshape(gold.shape)

    # Clamp so an all-padding batch divides by 1 instead of 0.
    n_valid = keep.sum().clamp(min=1)
    loss = (per_token * keep).sum() / n_valid

    metrics = {
        "loss_ce": loss.detach(),
        "loss_total": loss.detach(),
        "ppl": loss.detach().exp(),
        "valid_tokens": n_valid.detach(),
    }
    return loss, metrics
def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """Write ``ckpt_{step}.pt`` and refresh ``latest.pt`` inside ``save_dir``.

    The checkpoint holds the step, the model and optimizer state dicts, the
    config, and ``metrics`` with tensor-like values converted through
    ``.item()``.

    Both files are written to a temporary sibling first and then renamed into
    place, so an interrupted save never leaves a truncated ``latest.pt`` for
    ``--resume`` to load. The state is serialized once; ``latest.pt`` is a
    byte copy of the step file.

    Args:
        step: training step recorded in the checkpoint and in the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        config: configuration object stored alongside the weights.
        save_dir: target directory, created if missing.
        metrics: mapping of metric name to number or tensor-like value.
    """
    import shutil

    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
    }

    step_path = save_dir / f"ckpt_{step}.pt"
    latest_path = save_dir / "latest.pt"
    tmp_step = save_dir / f"ckpt_{step}.pt.tmp"
    tmp_latest = save_dir / "latest.pt.tmp"

    torch.save(ckpt, tmp_step)
    tmp_step.replace(step_path)  # rename within one directory

    shutil.copyfile(step_path, tmp_latest)
    tmp_latest.replace(latest_path)

    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")
@torch.no_grad()
def evaluate(model, dtype, val_loader, device, num_batches: int = 50) -> dict:
    """Average the ``ar_step`` metrics over up to ``num_batches`` batches.

    The model is put in eval mode for the duration and returned to train mode
    afterwards. Every metric is the unweighted mean of its per-batch values,
    so ``ppl`` is the mean of per-batch perplexities rather than the
    exponential of the mean loss. An empty loader yields an empty dict.
    """
    model.eval()
    sums: dict = {}
    seen = 0
    for idx, raw in enumerate(val_loader):
        if idx >= num_batches:
            break
        on_device = {name: t.to(device, non_blocking=True) for name, t in raw.items()}
        _, step_metrics = ar_step(on_device, model, dtype)
        for name, value in step_metrics.items():
            scalar = value.item() if hasattr(value, "item") else float(value)
            sums[name] = sums.get(name, 0.0) + scalar
        seen = idx + 1
    model.train()

    denom = max(seen, 1)
    return {name: total / denom for name, total in sums.items()}
def fmt_metric(v) -> str:
    """Format a metric with four decimal places.

    Objects exposing ``.item()`` (e.g. 0-d tensors) are unwrapped through it;
    anything else is converted with ``float()``.
    """
    if hasattr(v, "item"):
        number = v.item()
    else:
        number = float(v)
    return format(number, ".4f")
def main():
    """Entry point: train the autoregressive baseline described by the config.

    Steps: parse CLI flags and load the YAML config (CLI overrides applied),
    set up the device and, when ``WORLD_SIZE > 1``, the NCCL process group,
    build tokenizer / model / optimizer, optionally resume from a checkpoint,
    then run the training loop with periodic logging, validation and
    checkpointing. A final checkpoint is written at ``num_steps``.

    Logging, validation and checkpointing happen only on the main process,
    i.e. global rank 0 in distributed runs.
    """
    args = parse_args()
    config = load_config(args.config)
    # CLI overrides take precedence over the YAML values.
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    dist = None
    rank = local_rank
    if world_size > 1:
        import torch.distributed as dist
        dist.init_process_group("nccl")
        # Use the global rank: LOCAL_RANK is 0 on every node, which would make
        # each node of a multi-node job log and write checkpoints.
        rank = dist.get_rank()
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    is_main = (rank == 0)

    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Offset by the global rank so no two processes share a seed.
    set_seed(train_cfg.get("seed", 42) + rank)

    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)

    model_cfg = config["model"]
    model = ARModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        max_seq_len=model_cfg["max_seq_len"],
        dropout=model_cfg.get("dropout", 0.0),
    ).to(device)

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = train_cfg.get("compile", "default")
    # YAML 1.1 (PyYAML) reads a bare `off`/`no`/`false` as boolean False and
    # `on`/`yes`/`true` as True, so `compile: off` arrives here as False and
    # would not match the "off" string below.
    if compile_mode is False:
        compile_mode = "off"
    elif compile_mode is True:
        compile_mode = "default"
    if compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model)

    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    # ── Resume ────────────────────────────────────────────────────────────────
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        # Checkpoints store the unwrapped module's state dict.
        raw_model = _unwrap(model)
        raw_model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/ar_baseline"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        # Persist the effective config (after CLI and tokenizer overrides).
        with open(save_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

    wandb = None
    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "sad_ar_baseline"), config=config)
        except ImportError:
            # wandb is optional; fall back to console logging only.
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="AR baseline training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Restart the loader when an epoch is exhausted.
        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = ar_step(batch, model, dtype)

        # Symmetric NaN-skip across DDP ranks: if any rank sees a non-finite
        # loss, every rank skips the step so no rank enters backward alone.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if world_size > 1:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        optimizer.step()
        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                ce=fmt_metric(metrics["loss_ce"]),
                ppl=fmt_metric(metrics["ppl"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"ce={fmt_metric(metrics['loss_ce'])} | "
                f"ppl={fmt_metric(metrics['ppl'])} | "
                f"lr={lr:.2e} | "
                f"t={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step, "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v
                       for k, v in metrics.items()}
                })
            # `elapsed` measures the time since the previous log line.
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # Only the main process evaluates, so run the unwrapped module:
            # a forward through the DDP wrapper may issue collectives (e.g. a
            # buffer broadcast) that the other ranks would never join.
            val_metrics = evaluate(_unwrap(model), dtype, val_loader, device)
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_ce", "loss_total", "ppl")
            ))
            if use_wandb:
                wandb.log({"step": step,
                           **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            raw_model = _unwrap(model)
            save_checkpoint(step, raw_model, optimizer, config, save_dir, last_metrics)

    if is_main:
        raw_model = _unwrap(model)
        save_checkpoint(num_steps, raw_model, optimizer, config, save_dir, last_metrics)
        print("Training complete.")

    if pbar is not None:
        pbar.close()

    if world_size > 1:
        dist.destroy_process_group()
# Run training only when executed as a script (python / torchrun), not on import.
if __name__ == "__main__":
    main()