#!/usr/bin/env python3
"""
train_sad.py – SAD training script.

SAD training using SADModel.forward_vectorized (block-diff attention mask +
flex attention when available).

- Each step trains on the full seq_len (no curriculum).
- forward_vectorized: concatenates [noisy|clean], applies block-diff mask.

Usage:
    python scripts/train_sad.py --config configs/sad_owt.yaml

    torchrun --nproc_per_node=8 scripts/train_sad.py --config configs/sad_owt.yaml

    torchrun --nproc_per_node=8 scripts/train_sad.py \\
        --config configs/sad_owt.yaml \\
        --resume outputs/sad/latest.pt
"""

import sys
import os
import argparse
import math
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]  # sad/

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import yaml

sys.path.insert(0, str(ROOT))

from src.utils import set_seed, count_parameters, grad_norm
from src.models.sad_model import SADModel
from src.diffusion.ancestor_table import AncestorTable
from src.diffusion.noisy_state import NoisyStateBuilder
from src.losses.sad_loss import SADLoss
from src.data import build_debug_dataloader, build_owt_dataloader

try:
    from tqdm import tqdm
    _has_tqdm = True
except ImportError:
    _has_tqdm = False


def _unwrap(model):
    """Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to SADModel."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


def parse_args():
    """Parse command-line arguments.

    ``--local_rank`` is also accepted as ``--local-rank`` (the spelling that
    ``torch.distributed.launch`` passes in recent PyTorch releases). Under
    torchrun the ``LOCAL_RANK`` environment variable takes precedence anyway.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", default="configs/sad_owt.yaml")
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--num_steps", type=int, default=None,
                   help="Override training.num_steps in config")
    p.add_argument("--batch_size", type=int, default=None,
                   help="Override training.batch_size (per-GPU) in config")
    p.add_argument("--local_rank", "--local-rank", type=int, default=0)
    return p.parse_args()


def load_config(path: str) -> dict:
    """Load a YAML config file (UTF-8) with ``yaml.safe_load``."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_tokenizer(config: dict):
    """Build the tokenizer for the configured dataset.

    For ``data.dataset == "debug"`` a minimal mock tokenizer is returned whose
    mask token is the last vocab id. Otherwise the local GPT-2 tokenizer under
    ``ROOT/tokenizers/gpt2`` is loaded, missing special tokens are registered,
    and ``config["model"]["vocab_size"]`` (plus ``level_sizes[0]`` if present)
    is updated in place to ``len(tok)``.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    vocab_size = config["model"]["vocab_size"]

    if dataset == "debug":
        class MockTokenizer:
            def __init__(self, vocab_size, mask_token_id):
                self.vocab_size = vocab_size
                self.mask_token_id = mask_token_id
                self.pad_token_id = 0
                self.eos_token_id = 0  # debug dataset has no real pads; match pad==eos
                self.model_max_length = config["model"]["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        return MockTokenizer(vocab_size, vocab_size - 1)

    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    # The local tokenizer_config.json may not define special tokens;
    # register <|endoftext|> explicitly.
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})
    # Adding special tokens may grow the vocab; keep the model config in sync.
    config["model"]["vocab_size"] = len(tok)
    if "level_sizes" in config["model"]:
        config["model"]["level_sizes"][0] = len(tok)
    return tok


def build_dataloaders(config: dict, tokenizer):
    """Return ``(train_loader, val_loader)`` for the configured dataset.

    Raises:
        ValueError: if ``data.dataset`` is neither "debug" nor "openwebtext".
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        train_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=512,
            mask_token_id=tokenizer.mask_token_id,
        )
        val_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=64,
            mask_token_id=tokenizer.mask_token_id,
        )
    elif dataset == "openwebtext":
        mode = data_cfg.get("mode", "subsample")
        train_loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=data_cfg.get("num_workers", 4),
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_train_samples", None),
            mode=mode,
        )
        val_loader = build_owt_dataloader(
            tokenizer,
            split="train[-100000:]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=2,
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_val_samples", 100000),
            mode=mode,
            shard_across_ranks=False,  # eval runs on rank 0 only — don't shard
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    return train_loader, val_loader


def build_ancestor_table(config: dict, device, embed_dim: int) -> AncestorTable:
    """Build the AncestorTable, either from LUT files or as a random debug LUT.

    Relative ``lut_path`` / ``proto_path`` values are resolved against ROOT.
    """
    ancestor_cfg = config.get("ancestor", {})
    script_dir = ROOT

    lut_path = ancestor_cfg.get("lut_path", None)
    if lut_path is None:
        # Debug mode: generate a random LUT for the configured vocab_size.
        # Use an independent Generator seeded from config so every rank sees
        # the same LUT — the global RNG has already been perturbed by the
        # per-rank `set_seed(seed + rank)` in main().
        vocab_size = config["model"]["vocab_size"]
        K = ancestor_cfg.get("num_clusters", 64)
        top_k = ancestor_cfg.get("top_k", 3)
        seed = config.get("training", {}).get("seed", 42)
        print(f"[AncestorTable] No lut_path configured – generating random LUT "
              f"(V={vocab_size}, K={K}, top_k={top_k}, seed={seed})")
        g = torch.Generator().manual_seed(seed)
        indices = torch.randint(0, K, (vocab_size, top_k), generator=g)
        raw_w = torch.rand(vocab_size, top_k, generator=g)
        probs = raw_w / raw_w.sum(dim=-1, keepdim=True)
        init_emb = torch.randn(K, embed_dim, generator=g) * 0.02
        return AncestorTable(
            lut_indices=[indices],
            lut_probs=[probs],
            init_embeddings=[init_emb],
        ).to(device)

    lut_path = Path(lut_path) if Path(lut_path).is_absolute() else script_dir / lut_path
    proto_path = ancestor_cfg.get("proto_path", None)
    if proto_path is not None:
        proto_path = Path(proto_path) if Path(proto_path).is_absolute() else script_dir / proto_path

    table = AncestorTable.from_files(
        lut_path=lut_path,
        proto_path=proto_path,
        embed_dim=embed_dim,
        device=device,
    )
    return table.to(device)


def build_optimizer(config: dict, model: nn.Module, ancestor_table: AncestorTable):
    """Create one AdamW over the model and AncestorTable parameters.

    The fused AdamW kernel is only requested when every parameter lives on
    CUDA: older PyTorch releases reject ``fused=True`` for CPU tensors, and
    main() falls back to CPU when no GPU is available.
    """
    train_cfg = config["training"]
    params = list(model.parameters()) + list(ancestor_table.parameters())
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    use_fused = bool(params) and all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=use_fused,
    )


def get_lr(step: int, config: dict) -> float:
    """Learning rate at ``step``: linear warmup from 0, then cosine decay to ``lr_min``."""
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_min = train_cfg.get("lr_min", train_cfg["lr"] * 0.1)
    lr_max = train_cfg["lr"]

    if step < warmup:
        return lr_max * step / max(warmup, 1)
    progress = (step - warmup) / max(num_steps - warmup, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


def block_ar_step(
    batch: dict,
    model,
    ancestor_table: AncestorTable,
    loss_fn: SADLoss,
    noisy_builder: NoisyStateBuilder,
    tokenizer,
    dtype,
) -> tuple:
    """
    One Block-AR training step using forward_vectorized.

    Builds clean_embs and noisy_embs from the batch, calls
    model.forward_vectorized(noisy_embs, clean_embs), computes SAD loss.

    Returns:
        ``(loss, metrics)`` where ``metrics`` is the dict from ``loss_fn``
        extended with "mean_level", "mean_t" and "logits_absmax".
    """
    device = batch["input_ids"].device
    input_ids = batch["input_ids"]            # [B, L]
    attention_mask = batch["attention_mask"]  # [B, L]
    B, L = input_ids.shape

    autocast_device = "cuda" if device.type == "cuda" else "cpu"

    # Under DDP, get_leaf_embeddings only reads a parameter tensor and involves
    # no grad hooks, so peeling the wrappers (see _unwrap) is sufficient.
    # The forward computation itself must go through model(...) so that DDP's
    # gradient synchronization is triggered.
    raw_model = _unwrap(model)

    with torch.autocast(device_type=autocast_device, dtype=dtype):
        leaf_emb = raw_model.get_leaf_embeddings()  # [V, d]
        mask_emb = leaf_emb[tokenizer.mask_token_id]  # [d]
        clean_embs = leaf_emb[input_ids]  # [B, L, d]

        # HDLM γ=1 schedule: one t per sequence, per-token 3-state sampling.
        t = noisy_builder.sample_t(B, device=device)  # [B]
        levels = noisy_builder.sample_levels_hdlm(
            t, L, num_ancestor_levels=ancestor_table.num_levels,
        )  # [B, L]
        noisy_embs, ancestor_log_probs, ancestor_probs_per_lvl, corrupt_mask = \
            noisy_builder.build_noisy_embeddings(
                input_ids, levels, ancestor_table, leaf_emb, mask_emb
            )

        # Always pass attention_mask — branching on `(attention_mask == 0).any()`
        # would force a GPU→CPU sync every step. The mask-add cost is negligible.
        leaf_logits = model(
            noisy_embs=noisy_embs,
            clean_embs=clean_embs,
            attention_mask=attention_mask,
        )  # [B, L, V]

        loss, metrics = loss_fn(
            leaf_logits=leaf_logits,
            input_ids=input_ids,
            levels=levels,
            attention_mask=attention_mask,
            t=t,
            ancestor_log_probs=ancestor_log_probs,
            ancestor_probs_per_level=ancestor_probs_per_lvl,
            corrupt_mask=corrupt_mask,
        )

    metrics["mean_level"] = levels.float().mean().detach()
    metrics["mean_t"] = t.float().mean().detach()
    metrics["logits_absmax"] = leaf_logits.detach().abs().max()
    return loss, metrics


def _atomic_torch_save(obj, path: Path) -> None:
    """torch.save ``obj`` to ``path`` via a temp file + rename.

    ``os.replace`` is atomic on POSIX, so an interrupted save can never leave
    a truncated file at ``path`` (important for ``latest.pt``, the resume target).
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def save_checkpoint(step, model, ancestor_table, optimizer, config, save_dir: Path, metrics: dict):
    """Write ``ckpt_{step}.pt`` and ``latest.pt`` into ``save_dir``.

    ``model`` is expected to be the unwrapped SADModel (callers pass
    ``_unwrap(model)``) so state-dict keys carry no wrapper prefixes.
    Tensor-valued metrics are converted to Python numbers before saving.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "ancestor_table": ancestor_table.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, 'item') else v for k, v in metrics.items()},
    }
    _atomic_torch_save(ckpt, save_dir / f"ckpt_{step}.pt")
    _atomic_torch_save(ckpt, save_dir / "latest.pt")
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")


@torch.no_grad()
def evaluate(model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
             val_loader, device, num_batches: int = 50) -> dict:
    """Average block_ar_step metrics over at most ``num_batches`` validation batches.

    The model is switched to eval mode for the duration and is always
    restored to train mode afterwards, even if a batch raises.

    Returns:
        dict of metric name -> mean float value (empty if no batch was seen).
    """
    model.eval()
    totals: dict = {}
    count = 0
    try:
        for raw_batch in val_loader:
            if count >= num_batches:
                break
            # Move only the keys block_ar_step consumes (mirrors the training
            # loop); extra, possibly non-tensor, collate fields are ignored.
            batch = {
                "input_ids": raw_batch["input_ids"].to(device, non_blocking=True),
                "attention_mask": raw_batch["attention_mask"].to(device, non_blocking=True),
            }
            _, metrics = block_ar_step(
                batch, model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
            )
            for k, v in metrics.items():
                val = v.item() if hasattr(v, 'item') else float(v)
                totals[k] = totals.get(k, 0.0) + val
            count += 1
    finally:
        model.train()
    return {k: v / max(count, 1) for k, v in totals.items()}


def fmt_metric(v) -> str:
    """Format a scalar (tensor or number) with 4 decimals."""
    v = v.item() if hasattr(v, 'item') else float(v)
    return f"{v:.4f}"


def main():
    """Parse CLI/config, build model/data/optimizer, and run the training loop."""
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    # torchrun exports LOCAL_RANK (GPU index on this node), RANK (global
    # process index) and WORLD_SIZE. Rank-0-only duties (printing, eval,
    # checkpointing, wandb) must key on the *global* rank; otherwise every
    # node's local rank 0 would act as "main" in a multi-node run.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_main = (rank == 0)

    if world_size > 1:
        dist.init_process_group("nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Seed by global rank so processes on different nodes draw different noise.
    set_seed(train_cfg.get("seed", 42) + rank)
    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)

    # Flex-attention path in SADModel.forward_vectorized ignores the padding
    # mask under the assumption pad==eos (so attending to pads is harmless).
    # Guard that assumption here so a future pad-token change fails loudly.
    # (Explicit raise rather than assert: asserts are stripped under `python -O`.)
    if tokenizer.pad_token_id != tokenizer.eos_token_id:
        raise RuntimeError(
            f"forward_vectorized flex path assumes pad_token_id == eos_token_id, "
            f"got pad={tokenizer.pad_token_id}, eos={tokenizer.eos_token_id}. "
            f"See TODO in sad_model.py::forward_vectorized for packing support."
        )

    model_cfg = config["model"]
    model = SADModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        cond_dim=model_cfg["cond_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        block_size=model_cfg.get("block_size", 16),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 2),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    ).to(device)

    ancestor_table = build_ancestor_table(config, device, embed_dim=model_cfg["hidden_size"])

    # AncestorTable is not wrapped in DDP (which would auto-broadcast init
    # params). Per-rank set_seed() means any random init inside build/from_files
    # diverges across ranks. Broadcast rank 0's state so all ranks start from
    # the same parameters — grad all-reduce alone cannot undo an init mismatch.
    if world_size > 1:
        for p in ancestor_table.parameters():
            dist.broadcast(p.data, src=0)
        for b in ancestor_table.buffers():
            dist.broadcast(b.data, src=0)

    loss_cfg = config.get("loss", {})
    loss_fn = SADLoss(
        vocab_size=model_cfg["vocab_size"],
        lambda_ancestor=loss_cfg.get("lambda_ancestor", 0.0),
        ancestor_table=ancestor_table if loss_cfg.get("lambda_ancestor", 0.0) > 0 else None,
        mask_only=loss_cfg.get("mask_only", True),
    ).to(device)

    noisy_builder = NoisyStateBuilder(
        vocab_size=model_cfg["vocab_size"],
        mask_token_id=tokenizer.mask_token_id,
    )

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    # torch.compile for whole-graph kernel fusion. Compiled FlexAttention inside
    # DDiTBlockWithMask will be traced as part of the same graph.
    compile_mode = train_cfg.get("compile", "default")  # "off" to disable
    if compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model, ancestor_table)

    if is_main:
        print(f"Model params: {count_parameters(model):,}")
        print(f"AncestorTable params: {count_parameters(ancestor_table):,}")

    # ── Resume ────────────────────────────────────────────────────────────────
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        _unwrap(model).load_state_dict(ckpt["model"])
        if "ancestor_table" in ckpt:
            ancestor_table.load_state_dict(ckpt["ancestor_table"])
        try:
            optimizer.load_state_dict(ckpt["optimizer"])
        except ValueError:
            if is_main:
                print("[WARN] Optimizer state shape mismatch (e.g. tie_weights changed) "
                      "— skipping optimizer resume, restarting optimizer from scratch.")
        # Checkpoints record the last *completed* step, so continue after it.
        start_step = ckpt["step"] + 1
        # Release the loaded dict: its tensors were loaded onto `device` and
        # would otherwise stay referenced for the whole run.
        del ckpt
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/sad"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb  # binds `wandb` for every later wandb.log in main()
            wandb.init(project=log_cfg.get("project", "sad"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="block-ar training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Cycle the training loader indefinitely.
        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = block_ar_step(
            batch, model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
        )

        # NaN/Inf guard: occasional bf16 overflow in deep transformer → skip
        # the bad batch instead of killing a multi-hour run. Must be symmetric
        # across ranks in DDP (all skip or all proceed) to avoid desync.
        # The flag is built on-device so the only host sync is the .item() below.
        finite_flag = torch.isfinite(loss.detach()).to(torch.int32).reshape(1)
        if world_size > 1:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()

        # AncestorTable is not wrapped in DDP, so its gradients are NOT
        # all-reduced automatically. Sync them manually before clip/step,
        # otherwise each rank's ancestor embeddings drift independently.
        if world_size > 1:
            for p in ancestor_table.parameters():
                if p.grad is not None:
                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)

        if grad_clip > 0:
            nn.utils.clip_grad_norm_(
                list(model.parameters()) + list(ancestor_table.parameters()),
                grad_clip,
            )
        optimizer.step()
        last_metrics = metrics

        if pbar is not None:
            l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
            pbar.set_postfix(
                leaf=fmt_metric(l_leaf),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            l_total = metrics.get("loss_total", loss)
            l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
            l_ancestor = metrics.get("loss_ancestor", torch.tensor(0.0))
            print(
                f"step={step:6d} | "
                f"total={fmt_metric(l_total)} | "
                f"leaf={fmt_metric(l_leaf)} | "
                f"ancestor={fmt_metric(l_ancestor)} | "
                f"lr={lr:.2e} | "
                f"t={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step,
                    "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()}
                })
            t0 = time.time()

        # NOTE(review): eval runs on rank 0 only while the other ranks wait in
        # the next step's collectives; a long eval (e.g. recompilation for
        # eval mode) must finish within the process-group timeout.
        if is_main and step % eval_interval == 0 and step > 0:
            val_metrics = evaluate(
                model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
                val_loader, device,
            )
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_total", "loss_leaf", "loss_ancestor")
            ))
            if use_wandb:
                wandb.log({"step": step, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            save_checkpoint(step, _unwrap(model), ancestor_table, optimizer, config,
                            save_dir, last_metrics)

    if is_main:
        save_checkpoint(num_steps, _unwrap(model), ancestor_table, optimizer, config,
                        save_dir, last_metrics)
        print("Training complete.")

    if pbar is not None:
        pbar.close()
    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()