sad / scripts /train_sad.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
23.3 kB
#!/usr/bin/env python3
"""
train_sad.py – SAD training script.
SAD training using SADModel.forward_vectorized (block-diff attention mask
+ flex attention when available).
- Each step trains on the full seq_len (no curriculum).
- forward_vectorized: concatenates [noisy|clean], applies block-diff mask.
Usage:
python scripts/train_sad.py --config configs/sad_owt.yaml
torchrun --nproc_per_node=8 scripts/train_sad.py --config configs/sad_owt.yaml
torchrun --nproc_per_node=8 scripts/train_sad.py \\
--config configs/sad_owt.yaml \\
--resume outputs/sad/latest.pt
"""
import sys
import os
import argparse
import math
import time
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
sys.path.insert(0, str(ROOT))
from src.utils import set_seed, count_parameters, grad_norm
from src.models.sad_model import SADModel
from src.diffusion.ancestor_table import AncestorTable
from src.diffusion.noisy_state import NoisyStateBuilder
from src.losses.sad_loss import SADLoss
from src.data import build_debug_dataloader, build_owt_dataloader
try:
from tqdm import tqdm
_has_tqdm = True
except ImportError:
_has_tqdm = False
def _unwrap(model):
"""Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to SADModel."""
while True:
if hasattr(model, "_orig_mod"):
model = model._orig_mod
elif hasattr(model, "module"):
model = model.module
else:
return model
def parse_args():
    """Parse the training script's command-line arguments.

    Returns:
        argparse.Namespace with ``config``, ``resume``, ``num_steps``,
        ``batch_size`` (the last two override the YAML config when given) and
        ``local_rank``.

    ``--local-rank`` (hyphenated) is accepted as an alias of ``--local_rank``.
    ``torch.distributed.launch`` in PyTorch >= 2.0 passes the hyphenated
    spelling, which would otherwise abort with "unrecognized arguments".
    The value is only a fallback: main() prefers the LOCAL_RANK env var.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", default="configs/sad_owt.yaml")
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--num_steps", type=int, default=None,
                   help="Override training.num_steps in config")
    p.add_argument("--batch_size", type=int, default=None,
                   help="Override training.batch_size (per-GPU) in config")
    p.add_argument("--local_rank", "--local-rank", dest="local_rank",
                   type=int, default=0)
    return p.parse_args()
def load_config(path: str) -> dict:
    """Load the YAML training config at *path* and return it as a dict.

    The file is decoded as UTF-8 explicitly. Otherwise the platform locale
    encoding is used, and a config containing non-ASCII comments fails to load
    on non-UTF-8 systems.

    Raises:
        ValueError: if the top level of the document is not a mapping, e.g.
            an empty file (``yaml.safe_load`` returns ``None`` for it). Before
            this check, such a file only failed later with an opaque TypeError
            when main() indexed ``config["training"]``.
    """
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping at the top "
            f"level, got {type(config).__name__}"
        )
    return config
def build_tokenizer(config: dict):
    """Build the tokenizer that matches ``config["data"]["dataset"]``.

    * ``"debug"`` (the default): a minimal in-process mock whose vocabulary
      size is ``model.vocab_size``, whose ``[MASK]`` id is the last vocabulary
      index, and whose pad and eos ids are both 0.
    * any other dataset: the GPT-2 tokenizer loaded from
      ``<ROOT>/tokenizers/gpt2`` with ``local_files_only=True``.
      - Missing eos/bos/pad tokens are filled in with ``<|endoftext|>``.
      - A ``[MASK]`` token is added if absent.
      - ``config["model"]["vocab_size"]``, and ``level_sizes[0]`` when that
        key exists, are then overwritten in place with ``len(tokenizer)``.
    """
    model_cfg = config["model"]
    dataset = config.get("data", {}).get("dataset", "debug")
    vocab_size = model_cfg["vocab_size"]

    if dataset == "debug":
        class MockTokenizer:
            """Stand-in exposing only the attributes this script reads."""

            def __init__(self, vocab_size, mask_token_id):
                self.vocab_size = vocab_size
                self.mask_token_id = mask_token_id
                # The debug dataset has no real padding; keeping pad == eos
                # satisfies the pad/eos equality check in main().
                self.pad_token_id = 0
                self.eos_token_id = 0
                self.model_max_length = model_cfg["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        # The last vocabulary id is reserved for [MASK].
        return MockTokenizer(vocab_size, vocab_size - 1)

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    # The local tokenizer_config.json may not define special tokens, so
    # register <|endoftext|> explicitly and derive bos/pad from it.
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})

    # Adding special tokens can grow the vocabulary; keep the model config in sync.
    resolved_vocab = len(tok)
    model_cfg["vocab_size"] = resolved_vocab
    if "level_sizes" in model_cfg:
        model_cfg["level_sizes"][0] = resolved_vocab
    return tok
def build_dataloaders(config: dict, tokenizer):
    """Return ``(train_loader, val_loader)`` for ``config["data"]["dataset"]``.

    * ``"debug"`` (the default): synthetic loaders with 512 training and 64
      validation samples.
    * ``"openwebtext"``: the OWT ``train`` split. The last 100000 rows are held
      out for validation, and that loader is built with
      ``shard_across_ranks=False`` because evaluation runs on rank 0 only.

    Raises:
        KeyError: if ``config["training"]["batch_size"]`` is missing.
        ValueError: for any other dataset name.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        def make_debug_loader(num_samples):
            return build_debug_dataloader(
                vocab_size=config["model"]["vocab_size"],
                seq_len=seq_len,
                batch_size=batch_size,
                num_samples=num_samples,
                mask_token_id=tokenizer.mask_token_id,
            )

        return make_debug_loader(512), make_debug_loader(64)

    if dataset != "openwebtext":
        raise ValueError(f"Unknown dataset: {dataset}")

    mode = data_cfg.get("mode", "subsample")
    cache_dir = data_cfg.get("cache_dir", None)
    train_loader = build_owt_dataloader(
        tokenizer,
        split="train[:-100000]",
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=data_cfg.get("num_workers", 4),
        cache_dir=cache_dir,
        max_samples=data_cfg.get("max_train_samples", None),
        mode=mode,
    )
    val_loader = build_owt_dataloader(
        tokenizer,
        split="train[-100000:]",
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=2,
        cache_dir=cache_dir,
        max_samples=data_cfg.get("max_val_samples", 100000),
        mode=mode,
        shard_across_ranks=False,  # eval runs on rank 0 only — don't shard
    )
    return train_loader, val_loader
def build_ancestor_table(config: dict, device, embed_dim: int) -> AncestorTable:
    """Create the AncestorTable described by ``config["ancestor"]`` on *device*.

    When ``ancestor.lut_path`` is unset (debug mode), a random LUT is
    generated instead:
      - each of the ``model.vocab_size`` tokens gets ``ancestor.top_k``
        (default 3) cluster indices drawn from ``ancestor.num_clusters``
        (default 64);
      - each token's weights are random and normalised to sum to 1;
      - the cluster embeddings are ``0.02 * randn``.

    Otherwise the LUT, plus the optional ``ancestor.proto_path`` prototypes,
    is loaded with ``AncestorTable.from_files``. Relative paths are resolved
    against the repository root ``ROOT``.
    """
    anc_cfg = config.get("ancestor", {})

    def _resolve(path_like):
        candidate = Path(path_like)
        return candidate if candidate.is_absolute() else ROOT / path_like

    lut_path = anc_cfg.get("lut_path", None)
    if lut_path is not None:
        proto_path = anc_cfg.get("proto_path", None)
        table = AncestorTable.from_files(
            lut_path=_resolve(lut_path),
            proto_path=None if proto_path is None else _resolve(proto_path),
            embed_dim=embed_dim,
            device=device,
        )
        return table.to(device)

    # Debug mode. The global RNG was already offset per rank by set_seed() in
    # main(), so draw from a private Generator seeded only by training.seed.
    # That keeps the random LUT identical on every rank.
    vocab_size = config["model"]["vocab_size"]
    n_clusters = anc_cfg.get("num_clusters", 64)
    top_k = anc_cfg.get("top_k", 3)
    seed = config.get("training", {}).get("seed", 42)
    print(f"[AncestorTable] No lut_path configured – generating random LUT "
          f"(V={vocab_size}, K={n_clusters}, top_k={top_k}, seed={seed})")

    gen = torch.Generator().manual_seed(seed)
    # Draw order (indices, weights, embeddings) fixes the generated values.
    cluster_ids = torch.randint(0, n_clusters, (vocab_size, top_k), generator=gen)
    weights = torch.rand(vocab_size, top_k, generator=gen)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    cluster_emb = torch.randn(n_clusters, embed_dim, generator=gen) * 0.02
    table = AncestorTable(
        lut_indices=[cluster_ids],
        lut_probs=[weights],
        init_embeddings=[cluster_emb],
    )
    return table.to(device)
def build_optimizer(config: dict, model: nn.Module, ancestor_table: "AncestorTable"):
    """Create one AdamW optimizer over the model and AncestorTable parameters.

    Hyperparameters are read from ``config["training"]``:
      - ``lr`` (required);
      - ``weight_decay`` (default 0.02);
      - ``adam_betas`` (default ``(0.9, 0.99)``);
      - ``adam_eps`` (default 1e-9);
      - ``fused`` (optional; see below).

    The fused AdamW kernel is requested only when every parameter is a
    floating-point CUDA tensor, unless ``training.fused`` forces the choice.
    main() falls back to CPU when no GPU is present, and older PyTorch
    releases support ``fused=True`` only for CUDA tensors. The previously
    hard-coded ``fused=True`` therefore broke that CPU path.
    """
    train_cfg = config["training"]
    params = list(model.parameters()) + list(ancestor_table.parameters())
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    fused = train_cfg.get("fused")
    if fused is None:
        fused = bool(params) and all(
            p.is_cuda and p.is_floating_point() for p in params
        )
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=fused,
    )
def get_lr(step: int, config: dict) -> float:
    """Return the learning rate for *step*: linear warmup, then cosine decay.

    * ``step < warmup``: rises linearly from 0 towards ``training.lr``.
    * afterwards: cosine-anneals from ``training.lr`` down to
      ``training.lr_min`` (default ``0.1 * lr``), reached at
      ``training.num_steps``.

    ``warmup`` comes from ``training.warmup_steps`` and defaults to
    ``min(2000, num_steps // 100)``. Cosine progress is clamped to [0, 1], so
    any step at or beyond ``num_steps`` stays at ``lr_min``. Without the clamp
    the cosine wraps around and the LR climbs back to ``lr`` (for example when
    resuming a run whose step counter has passed ``num_steps``).
    """
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_min = train_cfg.get("lr_min", train_cfg["lr"] * 0.1)
    lr_max = train_cfg["lr"]
    if step < warmup:
        return lr_max * step / max(warmup, 1)
    progress = (step - warmup) / max(num_steps - warmup, 1)
    progress = min(max(progress, 0.0), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
def block_ar_step(
    batch: dict,
    model,
    ancestor_table: AncestorTable,
    loss_fn: SADLoss,
    noisy_builder: NoisyStateBuilder,
    tokenizer,
    dtype,
) -> tuple:
    """Run one Block-AR forward pass and return ``(loss, metrics)``.

    1. Embed the batch's clean tokens with the model's leaf embedding table.
    2. Sample one diffusion time ``t`` per sequence and per-token levels.
    3. Build the noisy embeddings with *noisy_builder*.
    4. Feed the ``(noisy_embs, clean_embs)`` pair through ``model(...)``.
    5. Score the resulting leaf logits with *loss_fn*.

    All of this runs under ``torch.autocast`` with *dtype*. The metrics dict
    from *loss_fn* is extended with detached ``mean_level``, ``mean_t`` and
    ``logits_absmax`` diagnostics.
    """
    input_ids = batch["input_ids"]            # [B, L]
    attention_mask = batch["attention_mask"]  # [B, L]
    device = input_ids.device
    batch_size, seq_len = input_ids.shape
    autocast_device = "cuda" if device.type == "cuda" else "cpu"

    # Under DDP, get_leaf_embeddings() only fetches a parameter tensor and
    # involves no grad hooks, so calling it on the unwrapped module is fine.
    # The forward pass itself must still go through model(...) so that DDP's
    # gradient synchronisation is triggered.
    core = _unwrap(model)
    with torch.autocast(device_type=autocast_device, dtype=dtype):
        leaf_table = core.get_leaf_embeddings()        # [V, d]
        mask_vec = leaf_table[tokenizer.mask_token_id]  # [d]
        clean_embs = leaf_table[input_ids]             # [B, L, d]

        # HDLM γ=1 schedule: one t per sequence, per-token 3-state sampling.
        t = noisy_builder.sample_t(batch_size, device=device)  # [B]
        levels = noisy_builder.sample_levels_hdlm(
            t, seq_len, num_ancestor_levels=ancestor_table.num_levels,
        )  # [B, L]
        noisy_out = noisy_builder.build_noisy_embeddings(
            input_ids, levels, ancestor_table, leaf_table, mask_vec
        )
        noisy_embs, anc_log_probs, anc_probs_per_lvl, corrupt_mask = noisy_out

        # The attention mask is always passed. Branching on
        # `(attention_mask == 0).any()` would force a GPU→CPU sync every step,
        # while the extra mask add is negligible.
        leaf_logits = model(
            noisy_embs=noisy_embs,
            clean_embs=clean_embs,
            attention_mask=attention_mask,
        )  # [B, L, V]

        loss, metrics = loss_fn(
            leaf_logits=leaf_logits,
            input_ids=input_ids,
            levels=levels,
            attention_mask=attention_mask,
            t=t,
            ancestor_log_probs=anc_log_probs,
            ancestor_probs_per_level=anc_probs_per_lvl,
            corrupt_mask=corrupt_mask,
        )
        metrics.update(
            mean_level=levels.float().mean().detach(),
            mean_t=t.float().mean().detach(),
            logits_absmax=leaf_logits.detach().abs().max(),
        )
    return loss, metrics
def save_checkpoint(step, model, ancestor_table, optimizer, config,
                    save_dir: Path, metrics: dict):
    """Write a resumable checkpoint to ``save_dir/ckpt_{step}.pt`` and ``save_dir/latest.pt``.

    The saved dict contains:
      - ``step``;
      - the ``model``, ``ancestor_table`` and ``optimizer`` state dicts;
      - the full ``config``;
      - ``metrics``, with any value exposing ``.item()`` converted to a Python
        scalar.

    Both files are written atomically. The data goes to a ``*.tmp`` file in
    *save_dir* first and is then moved into place with ``os.replace``, which
    is atomic on POSIX. An interrupted save (crash, preemption, full disk)
    therefore never leaves a truncated ``latest.pt``, the file ``--resume`` is
    usually pointed at. ``latest.pt`` is a byte copy of the step file rather
    than a second ``torch.save`` of the same object.
    """
    import shutil

    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "ancestor_table": ancestor_table.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, 'item') else v for k, v in metrics.items()},
    }
    step_path = save_dir / f"ckpt_{step}.pt"
    latest_path = save_dir / "latest.pt"
    step_tmp = step_path.with_name(step_path.name + ".tmp")
    latest_tmp = latest_path.with_name(latest_path.name + ".tmp")
    try:
        torch.save(ckpt, step_tmp)
        os.replace(step_tmp, step_path)
        shutil.copyfile(step_path, latest_tmp)
        os.replace(latest_tmp, latest_path)
    finally:
        # After a successful save both temp files have been renamed away; if
        # a write failed part-way, remove the partial file.
        for tmp in (step_tmp, latest_tmp):
            tmp.unlink(missing_ok=True)
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")
@torch.no_grad()
def evaluate(model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
             val_loader, device, num_batches: int = 50) -> dict:
    """Average ``block_ar_step`` metrics over at most *num_batches* validation batches.

    The model runs in eval mode under ``torch.no_grad()``. Afterwards the
    model's previous train/eval mode is restored, including when an exception
    escapes. The old code unconditionally called ``model.train()`` and was
    skipped entirely on error.

    Returns:
        ``{metric_name: mean over evaluated batches}`` as Python floats, or an
        empty dict when no batch was evaluated.

    NOTE(review): main() calls this on rank 0 only, passing the possibly
    DDP-wrapped model. If DDP broadcasts module buffers (its default) and
    SADModel registers any, that rank-0-only forward may start a collective
    the other ranks never join. Confirm the model has no buffers, or evaluate
    on the unwrapped module.
    """
    was_training = model.training
    model.eval()
    totals: dict = {}
    count = 0
    try:
        # zip() stops as soon as range() is exhausted, so no batch beyond the
        # first num_batches is ever pulled from val_loader.
        for _, batch in zip(range(num_batches), val_loader):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            _, metrics = block_ar_step(
                batch, model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
            )
            for k, v in metrics.items():
                val = v.item() if hasattr(v, 'item') else float(v)
                totals[k] = totals.get(k, 0.0) + val
            count += 1
    finally:
        model.train(was_training)
    return {k: v / max(count, 1) for k, v in totals.items()}
def fmt_metric(v) -> str:
    """Format a metric value as a string with exactly four decimals.

    Values exposing ``.item()`` (one-element tensors, NumPy scalars) are
    unwrapped first; anything else is converted with ``float()``.
    """
    if not hasattr(v, "item"):
        return format(float(v), ".4f")
    return format(v.item(), ".4f")
def main():
    """Run SAD training end to end, single-process or under torchrun/DDP.

    Flow:
      1. Parse CLI args and load the YAML config; ``--num_steps`` and
         ``--batch_size`` override ``training.*``.
      2. Pick the device. When ``WORLD_SIZE > 1``, create an NCCL process
         group.
      3. Build the tokenizer, SADModel, AncestorTable, SADLoss and
         NoisyStateBuilder.
      4. Wrap the model in DDP and ``torch.compile``.
      5. Optionally resume from ``--resume``.
      6. Train for ``training.num_steps`` steps with logging, rank-0
         evaluation and periodic checkpoints, then save a final checkpoint.

    Rank handling: the main process (printing, eval, checkpoints, wandb) and
    the per-rank seed offset are chosen by the *global* rank from the ``RANK``
    env var, falling back to the local rank when it is unset. ``LOCAL_RANK``
    repeats on every node. Using it made each node's first process act as
    main in multi-node runs, so several processes wrote the same
    checkpoints/wandb run. It also gave identical seeds to ranks sharing a
    local index. On a single node ``RANK == LOCAL_RANK``, so behaviour there
    is unchanged.
    """
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # Global rank, as exported by torchrun; falls back to local_rank for
    # launchers that only provide a local index.
    rank = int(os.environ.get("RANK", local_rank))
    is_main = (rank == 0)

    if world_size > 1:
        import torch.distributed as dist
        device = torch.device(f"cuda:{local_rank}")
        # Bind this process to its GPU before creating the NCCL group, so the
        # communicator (and any barrier) is set up on the right device.
        torch.cuda.set_device(device)
        dist.init_process_group("nccl")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    set_seed(train_cfg.get("seed", 42) + rank)
    dtype_str = train_cfg.get("dtype", "bf16")
    # NOTE(review): "fp16" is accepted but no GradScaler is used below, so
    # fp16 gradients may underflow; bf16 is the safer choice.
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)
    # The flex-attention path in SADModel.forward_vectorized ignores the
    # padding mask on the assumption that pad == eos (so attending to pads is
    # harmless). Check that assumption here so a future pad-token change fails
    # loudly.
    assert tokenizer.pad_token_id == tokenizer.eos_token_id, (
        f"forward_vectorized flex path assumes pad_token_id == eos_token_id, "
        f"got pad={tokenizer.pad_token_id}, eos={tokenizer.eos_token_id}. "
        f"See TODO in sad_model.py::forward_vectorized for packing support."
    )

    model_cfg = config["model"]
    model = SADModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        cond_dim=model_cfg["cond_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        block_size=model_cfg.get("block_size", 16),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 2),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    ).to(device)
    ancestor_table = build_ancestor_table(config, device, embed_dim=model_cfg["hidden_size"])

    # AncestorTable is not wrapped in DDP, which would broadcast its initial
    # params automatically. Because set_seed() is offset per rank, any random
    # init inside build/from_files differs across ranks. Broadcast rank 0's
    # state so all ranks start from the same parameters; gradient all-reduce
    # alone cannot undo an init mismatch.
    if world_size > 1:
        import torch.distributed as dist
        for p in ancestor_table.parameters():
            dist.broadcast(p.data, src=0)
        for b in ancestor_table.buffers():
            dist.broadcast(b.data, src=0)

    loss_cfg = config.get("loss", {})
    loss_fn = SADLoss(
        vocab_size=model_cfg["vocab_size"],
        lambda_ancestor=loss_cfg.get("lambda_ancestor", 0.0),
        ancestor_table=ancestor_table if loss_cfg.get("lambda_ancestor", 0.0) > 0 else None,
        mask_only=loss_cfg.get("mask_only", True),
    ).to(device)
    noisy_builder = NoisyStateBuilder(
        vocab_size=model_cfg["vocab_size"],
        mask_token_id=tokenizer.mask_token_id,
    )

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    # torch.compile for whole-graph kernel fusion. The compiled FlexAttention
    # inside DDiTBlockWithMask is traced as part of the same graph.
    compile_mode = train_cfg.get("compile", "default")  # "off" to disable
    if compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model, ancestor_table)
    if is_main:
        print(f"Model params: {count_parameters(model):,}")
        print(f"AncestorTable params: {count_parameters(ancestor_table):,}")

    # ── Resume ────────────────────────────────────────────────────────────────
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        raw_model = _unwrap(model)
        raw_model.load_state_dict(ckpt["model"])
        if "ancestor_table" in ckpt:
            ancestor_table.load_state_dict(ckpt["ancestor_table"])
        try:
            optimizer.load_state_dict(ckpt["optimizer"])
        except ValueError:
            if is_main:
                print("[WARN] Optimizer state shape mismatch (e.g. tie_weights changed) "
                      "— skipping optimizer resume, restarting optimizer from scratch.")
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/sad"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "sad"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="block-ar training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)
        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = block_ar_step(
            batch, model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
        )

        # NaN/Inf guard: an occasional bf16 overflow in a deep transformer
        # should skip the bad batch rather than kill a multi-hour run. Under
        # DDP the decision must be the same on every rank (all skip or all
        # proceed) to avoid desync. The flag is built on-device from the
        # loss, so the only GPU→CPU sync is the single .item() below.
        finite_flag = torch.isfinite(loss.detach()).to(torch.int32).reshape(1)
        if world_size > 1:
            import torch.distributed as dist
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    import wandb
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()
        # AncestorTable is not wrapped in DDP, so its gradients are NOT
        # all-reduced automatically. Sync them manually before clip/step;
        # otherwise each rank's ancestor embeddings drift independently.
        if world_size > 1:
            import torch.distributed as dist
            for p in ancestor_table.parameters():
                if p.grad is not None:
                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(
                list(model.parameters()) + list(ancestor_table.parameters()),
                grad_clip,
            )
        optimizer.step()
        last_metrics = metrics

        if pbar is not None:
            l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
            pbar.set_postfix(
                leaf=fmt_metric(l_leaf),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            l_total = metrics.get("loss_total", loss)
            l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
            l_ancestor = metrics.get("loss_ancestor", torch.tensor(0.0))
            print(
                f"step={step:6d} | "
                f"total={fmt_metric(l_total)} | "
                f"leaf={fmt_metric(l_leaf)} | "
                f"ancestor={fmt_metric(l_ancestor)} | "
                f"lr={lr:.2e} | "
                f"t={elapsed:.1f}s"
            )
            if use_wandb:
                import wandb
                wandb.log({
                    "step": step, "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v
                       for k, v in metrics.items()}
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            val_metrics = evaluate(
                model, ancestor_table, loss_fn, noisy_builder, tokenizer, dtype,
                val_loader, device,
            )
            print("  VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_total", "loss_leaf", "loss_ancestor")
            ))
            if use_wandb:
                import wandb
                wandb.log({"step": step,
                           **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            raw_model = _unwrap(model)
            save_checkpoint(step, raw_model, ancestor_table, optimizer,
                            config, save_dir, last_metrics)

    if is_main:
        raw_model = _unwrap(model)
        save_checkpoint(num_steps, raw_model, ancestor_table, optimizer,
                        config, save_dir, last_metrics)
        print("Training complete.")
    if pbar is not None:
        pbar.close()
    if world_size > 1:
        import torch.distributed as dist
        dist.destroy_process_group()


if __name__ == "__main__":
    main()