sad / scripts /train_block_diffusion.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
18.7 kB
#!/usr/bin/env python3
"""
train_block_diffusion.py – Block-form mask diffusion training (no ancestor states).
Uses the same SADModel block architecture and forward_vectorized training path as
train_sad.py, but the corruption process is binary:
- level 0: clean
- level 1: mask
No AncestorTable is created, and the loss is the binary MDLM/SUBS-style masked
token objective over corrupted positions only.
Usage:
python scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml
torchrun --nproc_per_node=8 scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml
torchrun --nproc_per_node=8 scripts/train_block_diffusion.py \
--config configs/block_diffusion_owt_b32.yaml \
--resume outputs/block_diffusion/latest.pt
"""
import sys
import os
import argparse
import math
import time
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1] # sad/
import torch
import torch.nn as nn
import yaml
sys.path.insert(0, str(ROOT))
from src.utils import set_seed, count_parameters
from src.models.sad_model import SADModel
from src.diffusion.noisy_state import NoisyStateBuilder
from src.losses.sad_loss import SADLoss
from src.data import build_debug_dataloader, build_owt_dataloader
try:
from tqdm import tqdm
_has_tqdm = True
except ImportError:
_has_tqdm = False
def _unwrap(model):
    """Return the innermost model beneath DDP / torch.compile wrappers.

    ``torch.compile`` keeps the wrapped module in ``_orig_mod`` and DDP keeps
    it in ``module``. Wrapper layers are peeled one at a time, with
    ``_orig_mod`` taking precedence when an object has both. The first object
    exposing neither attribute is returned.
    """
    current = model
    while True:
        for attr in ("_orig_mod", "module"):
            if hasattr(current, attr):
                current = getattr(current, attr)
                break
        else:
            # Neither wrapper attribute present: this is the bare model.
            return current
def parse_args():
    """Parse the command-line options for block-diffusion training.

    Options: ``--config`` (YAML path), ``--resume`` (checkpoint path),
    ``--num_steps`` / ``--batch_size`` (override the config values) and the
    launcher-provided local rank.

    The local rank is accepted as both ``--local_rank`` and ``--local-rank``.
    Since PyTorch 2.0, ``torch.distributed.launch`` passes the dashed form,
    which argparse would otherwise reject as an unrecognized argument. Both
    spellings are stored in ``args.local_rank``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", default="configs/block_diffusion_owt_b32.yaml")
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--num_steps", type=int, default=None,
                   help="Override training.num_steps in config")
    p.add_argument("--batch_size", type=int, default=None,
                   help="Override training.batch_size (per-GPU) in config")
    p.add_argument("--local_rank", "--local-rank", type=int, default=0)
    return p.parse_args()
def load_config(path: str) -> dict:
    """Load the YAML training config at ``path`` with ``yaml.safe_load``.

    The file is decoded explicitly as UTF-8. Without this, ``open`` falls
    back to the locale's preferred encoding, and configs containing
    non-ASCII characters fail or mis-decode on non-UTF-8 systems.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def build_tokenizer(config: dict):
    """Return the tokenizer for the configured dataset.

    With ``data.dataset == "debug"`` (the default), this returns an
    in-memory stand-in with ``model.vocab_size`` ids. The last id is [MASK],
    and id 0 serves as both PAD and EOS.

    For any other dataset, the GPT-2 tokenizer is loaded from
    ``ROOT/tokenizers/gpt2`` with ``local_files_only=True``. Missing
    EOS / BOS / PAD / MASK tokens are filled in (BOS and PAD reuse EOS).
    The resulting length is then written back into
    ``config["model"]["vocab_size"]``. If a non-empty ``level_sizes`` list
    is configured, it is also written into ``level_sizes[0]``.
    """
    dataset_name = config.get("data", {}).get("dataset", "debug")
    model_cfg = config["model"]
    base_vocab = model_cfg["vocab_size"]

    if dataset_name == "debug":
        class MockTokenizer:
            """Minimal tokenizer stand-in used with the synthetic debug data."""

            def __init__(self, vocab_size, mask_token_id):
                self.vocab_size = vocab_size
                self.mask_token_id = mask_token_id
                self.pad_token_id = 0
                self.eos_token_id = 0
                self.model_max_length = model_cfg["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        # The final id of the vocabulary is reserved for [MASK].
        return MockTokenizer(base_vocab, base_vocab - 1)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "<|endoftext|>"})
    # BOS and PAD fall back to EOS when the tokenizer defines none.
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    # Added special tokens may have grown the vocabulary; keep the model config in sync.
    final_vocab = len(tokenizer)
    model_cfg["vocab_size"] = final_vocab
    if model_cfg.get("level_sizes"):
        model_cfg["level_sizes"][0] = final_vocab
    return tokenizer
def build_dataloaders(config: dict, tokenizer):
    """Create ``(train_loader, val_loader)`` for ``config["data"]["dataset"]``.

    ``"debug"`` (default): two synthetic loaders from ``build_debug_dataloader``
    with 512 training samples and 64 validation samples.

    ``"openwebtext"``: both loaders come from ``build_owt_dataloader`` with
    split strings ``train[:-100000]`` (train) and ``train[-100000:]``
    (validation). The validation loader uses 2 workers and is built with
    ``shard_across_ranks=False``.

    Raises:
        ValueError: for any other dataset name.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        shared = dict(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
        )
        train_loader = build_debug_dataloader(
            **shared, num_samples=512, mask_token_id=tokenizer.mask_token_id,
        )
        val_loader = build_debug_dataloader(
            **shared, num_samples=64, mask_token_id=tokenizer.mask_token_id,
        )
        return train_loader, val_loader

    if dataset != "openwebtext":
        raise ValueError(f"Unknown dataset: {dataset}")

    mode = data_cfg.get("mode", "subsample")
    cache_dir = data_cfg.get("cache_dir", None)
    train_loader = build_owt_dataloader(
        tokenizer,
        split="train[:-100000]",
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=data_cfg.get("num_workers", 4),
        cache_dir=cache_dir,
        max_samples=data_cfg.get("max_train_samples", None),
        mode=mode,
    )
    val_loader = build_owt_dataloader(
        tokenizer,
        split="train[-100000:]",
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=2,
        cache_dir=cache_dir,
        max_samples=data_cfg.get("max_val_samples", 100000),
        mode=mode,
        shard_across_ranks=False,
    )
    return train_loader, val_loader
def build_optimizer(config: dict, model: nn.Module):
    """Build an AdamW optimizer over every parameter of ``model``.

    Hyper-parameters are read from ``config["training"]``:
    ``lr`` (required), ``weight_decay`` (default 0.02), ``adam_betas``
    (default (0.9, 0.99)), ``adam_eps`` (default 1e-9) and ``fused``
    (optional).

    By default the fused AdamW kernel is requested only when every
    parameter is on a CUDA device. Older PyTorch releases reject
    ``fused=True`` for CPU tensors, which would break the CPU fallback this
    script supports. Set ``training.fused`` to force the choice either way.
    """
    train_cfg = config["training"]
    params = list(model.parameters())
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    default_fused = all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=bool(train_cfg.get("fused", default_fused)),
    )
def get_lr(step: int, config: dict) -> float:
    """Learning rate for ``step``: linear warmup followed by cosine decay.

    For the first ``warmup_steps`` steps (default ``min(2000, num_steps // 100)``),
    the rate rises linearly from 0 toward ``lr``. After that it follows a half
    cosine from ``lr`` down to ``lr_min`` (default ``0.1 * lr``). When
    ``num_steps > warmup_steps``, ``lr_min`` is reached at ``step == num_steps``.
    """
    tcfg = config["training"]
    total_steps = tcfg["num_steps"]
    warmup_steps = tcfg.get("warmup_steps", min(2000, total_steps // 100))
    floor = tcfg.get("lr_min", tcfg["lr"] * 0.1)
    peak = tcfg["lr"]

    if step >= warmup_steps:
        decay_span = max(total_steps - warmup_steps, 1)
        frac = (step - warmup_steps) / decay_span
        return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * frac))
    # Warmup phase; max(..., 1) guards a zero-length warmup.
    return peak * step / max(warmup_steps, 1)
def build_mask_noisy_embeddings(
    input_ids: torch.Tensor,
    levels: torch.Tensor,
    leaf_embeddings: torch.Tensor,
    mask_embedding: torch.Tensor,
):
    """Binary corruption: level 0 keeps the leaf embedding, level 1 uses [MASK].

    Args:
        input_ids: integer token ids used to index ``leaf_embeddings``.
        levels: same shape as ``input_ids``; any non-zero entry marks that
            position as corrupted.
        leaf_embeddings: embedding table indexed by token id (assumed
            ``[V, d]``).
        mask_embedding: ``[d]`` vector placed at every corrupted position.
            It is cast to the dtype of the gathered embeddings.

    Returns:
        ``(noisy_embs, corrupt_mask)``. ``noisy_embs`` has shape
        ``input_ids.shape + (d,)``, and ``corrupt_mask`` is ``levels.bool()``.

    Implementation notes:
        The selection uses one broadcasting ``torch.where`` instead of an
        in-place masked assignment. The gather already yields a fresh
        tensor, so no defensive copy is needed. No data-dependent Python
        branch is used either, which avoids a host-device sync per call.
        Gradients reach the table rows of clean positions and, summed over
        corrupted positions, ``mask_embedding``.
    """
    gathered = leaf_embeddings[input_ids]
    corrupt_mask = levels.bool()
    fill = mask_embedding.to(gathered.dtype)
    # unsqueeze(-1) lets the per-token mask broadcast over the embedding dim.
    noisy_embs = torch.where(corrupt_mask.unsqueeze(-1), fill, gathered)
    return noisy_embs, corrupt_mask
def sample_binary_levels(
    noisy_builder: NoisyStateBuilder,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    t_eps: float,
):
    """Draw a per-sequence mask rate and an i.i.d. binary corruption pattern.

    One time value ``t`` is drawn per sequence with
    ``noisy_builder.sample_t(batch_size, device=device, eps=t_eps)``. Every
    token in that sequence is then masked independently with probability ``t``.

    Returns:
        t: ``[B]`` tensor from ``sample_t``. Its range is set by ``sample_t``
           given ``eps=t_eps``; presumably ``[t_eps, 1 - t_eps]`` (confirm there).
        levels: ``[B, S]`` int64 tensor; 0 = clean, 1 = mask.
    """
    t = noisy_builder.sample_t(batch_size, device=device, eps=t_eps)
    # Broadcast each sequence's rate across its positions, then flip coins.
    keep_or_mask_prob = t.unsqueeze(1).expand(batch_size, seq_len)
    levels = torch.bernoulli(keep_or_mask_prob).long()
    return t, levels
def block_mask_step(
    batch: dict,
    model,
    loss_fn: SADLoss,
    noisy_builder: NoisyStateBuilder,
    tokenizer,
    dtype,
    t_eps: float,
) -> tuple:
    """Run one forward pass and loss for a binary-mask diffusion batch.

    Under autocast with ``dtype``, this:
      * takes the embedding table from the unwrapped model's
        ``get_leaf_embeddings()``,
      * samples per-sequence ``t`` and binary ``levels`` via
        ``sample_binary_levels``,
      * replaces corrupted positions with the [MASK] embedding,
      * calls ``model`` with the noisy embeddings, the clean embeddings and
        the attention mask, then scores the logits with ``loss_fn``.

    Returns:
        ``(loss, metrics)``. ``metrics`` is the dict returned by ``loss_fn``,
        extended with detached ``mean_level`` (fraction of masked positions),
        ``mean_t`` and ``logits_absmax``.
    """
    ids = batch["input_ids"]            # [B, L]
    attn = batch["attention_mask"]      # [B, L]
    dev = ids.device
    bsz, length = ids.shape
    amp_device = "cuda" if dev.type == "cuda" else "cpu"
    core = _unwrap(model)

    with torch.autocast(device_type=amp_device, dtype=dtype):
        table = core.get_leaf_embeddings()      # [V, d]
        mask_vec = table[tokenizer.mask_token_id]
        clean = table[ids]                      # [B, L, d]

        t, levels = sample_binary_levels(
            noisy_builder, bsz, length, device=dev, t_eps=t_eps,
        )
        noisy, corrupted = build_mask_noisy_embeddings(ids, levels, table, mask_vec)

        # Call the (possibly DDP/compiled) wrapper, not `core`, so wrapper
        # hooks still run for the model forward.
        logits = model(noisy_embs=noisy, clean_embs=clean, attention_mask=attn)
        loss, metrics = loss_fn(
            leaf_logits=logits,
            input_ids=ids,
            levels=levels,
            attention_mask=attn,
            t=t,
            corrupt_mask=corrupted,
        )

    metrics.update({
        "mean_level": levels.float().mean().detach(),
        "mean_t": t.float().mean().detach(),
        "logits_absmax": logits.detach().abs().max(),
    })
    return loss, metrics
def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """Write a checkpoint as ``ckpt_{step}.pt`` and refresh ``latest.pt``.

    The payload holds ``step``, the model and optimizer ``state_dict``s, the
    config, and ``metrics``, with values that have ``.item()`` converted to
    Python scalars.

    The state is serialized only once; ``latest.pt`` is a byte copy of the
    step file. Each file is first written under a temporary name inside
    ``save_dir`` and then moved into place with ``os.replace``. As a result,
    an interrupted save cannot leave a truncated ``latest.pt`` for
    ``--resume`` to load.
    """
    import shutil

    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
    }

    step_path = save_dir / f"ckpt_{step}.pt"
    step_tmp = save_dir / f".ckpt_{step}.pt.tmp"
    torch.save(ckpt, step_tmp)
    os.replace(step_tmp, step_path)

    # Copy bytes instead of serializing again (avoids a second state_dict pass).
    latest_path = save_dir / "latest.pt"
    latest_tmp = save_dir / ".latest.pt.tmp"
    shutil.copyfile(step_path, latest_tmp)
    os.replace(latest_tmp, latest_path)
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")
@torch.no_grad()
def evaluate(model, loss_fn, noisy_builder, tokenizer, dtype, val_loader, device,
             t_eps: float, num_batches: int = 50) -> dict:
    """Return validation metrics averaged over at most ``num_batches`` batches.

    Each batch is moved to ``device`` and scored with ``block_mask_step``
    with gradients disabled. Every metric is converted to a Python float and
    averaged over the batches actually seen; an empty loader gives ``{}``.

    The model is put in eval mode for the pass. Afterwards it is restored to
    the mode it had on entry, even if a batch raises. Batches are taken with
    ``islice``, so no batch beyond the limit is pulled from the loader.
    """
    from itertools import islice

    was_training = model.training
    model.eval()
    totals = {}
    count = 0
    try:
        for batch in islice(val_loader, max(num_batches, 0)):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            _, metrics = block_mask_step(
                batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
            )
            for k, v in metrics.items():
                val = v.item() if hasattr(v, "item") else float(v)
                totals[k] = totals.get(k, 0.0) + val
            count += 1
    finally:
        model.train(was_training)
    return {k: v / max(count, 1) for k, v in totals.items()}
def fmt_metric(v) -> str:
    """Format a scalar metric (a plain number or anything with ``.item()``) to 4 decimals."""
    scalar = v.item() if hasattr(v, "item") else float(v)
    return format(scalar, ".4f")
def _resolve_compile_mode(value):
    """Map the ``training.compile`` config value to a ``torch.compile`` mode.

    Returns ``None`` when compilation is disabled, otherwise the mode string.

    PyYAML follows YAML 1.1, where an unquoted ``off`` / ``no`` / ``false``
    loads as the boolean ``False`` rather than a string. ``compile: off``
    must therefore be recognised in that form too; otherwise ``False != "off"``
    and the model would still be handed to ``torch.compile``.
    ``True`` and ``None`` select the default mode. The strings ``"off"``,
    ``"false"``, ``"no"`` and ``"none"`` (case-insensitive) disable
    compilation. Any other value is returned unchanged.
    """
    if value is False:
        return None
    if value is True or value is None:
        return "default"
    if isinstance(value, str) and value.strip().lower() in {"off", "false", "no", "none"}:
        return None
    return value


def main():
    """Run block-mask diffusion training end to end.

    Applies CLI overrides on top of the YAML config and sets up either a
    single process or a torchrun/DDP process group. It then builds the
    tokenizer, model, loss, noise sampler and optimizer, optionally resumes
    from ``--resume``, and runs the training loop with periodic logging,
    rank-0 validation and checkpointing.
    """
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    # LOCAL_RANK selects this process's GPU on its node; RANK is the global
    # process index. "Main" must be global rank 0, otherwise every node's
    # local rank 0 would log, evaluate and write checkpoints in multi-node runs.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_main = (rank == 0)

    dist = None
    if world_size > 1:
        import torch.distributed as dist
        dist.init_process_group("nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Offset by the global rank so processes on different nodes do not share
    # a seed (and therefore identical noise draws).
    set_seed(train_cfg.get("seed", 42) + rank)
    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)
    # Explicit check instead of `assert` so it is not stripped under `python -O`.
    if tokenizer.pad_token_id != tokenizer.eos_token_id:
        raise ValueError(
            f"forward_vectorized flex path assumes pad_token_id == eos_token_id, "
            f"got pad={tokenizer.pad_token_id}, eos={tokenizer.eos_token_id}."
        )

    model_cfg = config["model"]
    model = SADModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        cond_dim=model_cfg["cond_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        block_size=model_cfg.get("block_size", 16),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 1),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    ).to(device)

    loss_cfg = config.get("loss", {})
    loss_fn = SADLoss(
        vocab_size=model_cfg["vocab_size"],
        lambda_ancestor=0.0,
        ancestor_table=None,
        mask_only=loss_cfg.get("mask_only", True),
        use_mdlm=loss_cfg.get("use_mdlm", True),
        mdlm_masked_sum_over_all_tokens=True,
    ).to(device)

    noisy_builder = NoisyStateBuilder(
        vocab_size=model_cfg["vocab_size"],
        mask_token_id=tokenizer.mask_token_id,
    )

    if dist is not None:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = _resolve_compile_mode(train_cfg.get("compile", "default"))
    if compile_mode is not None:
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model)
    # fp16 autocast needs dynamic loss scaling so small gradients do not
    # underflow. For bf16 / fp32 (or on CPU) the scaler is disabled, and
    # scale / unscale_ / step / update are pass-throughs.
    scaler = torch.amp.GradScaler(
        device.type, enabled=(dtype == torch.float16 and device.type == "cuda"),
    )
    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        _unwrap(model).load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/block_diffusion"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "block_diffusion"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    t_eps = float(train_cfg.get("t_eps", 1e-3))
    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="Block-mask diffusion training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)
        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = block_mask_step(
            batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
        )

        # Every rank must agree on skipping; a rank that alone skipped
        # backward would leave the others waiting on DDP gradient reductions.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if dist is not None:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
            if use_wandb:
                wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        scaler.scale(loss).backward()
        if grad_clip > 0:
            # Unscale first so the clip threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                leaf=fmt_metric(metrics["loss_leaf"]),
                total=fmt_metric(metrics["loss_total"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"leaf={fmt_metric(metrics['loss_leaf'])} | "
                f"total={fmt_metric(metrics['loss_total'])} | "
                f"t={fmt_metric(metrics['mean_t'])} | "
                f"mask={fmt_metric(metrics['mean_level'])} | "
                f"lr={lr:.2e} | "
                f"t_wall={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step,
                    "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # Only rank 0 evaluates. Under DDP, run the bare module so DDP's
            # forward (which can broadcast buffers) issues no collective
            # that the other ranks never join.
            eval_model = _unwrap(model) if dist is not None else model
            val_metrics = evaluate(
                eval_model, loss_fn, noisy_builder, tokenizer, dtype, val_loader, device, t_eps,
            )
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_leaf", "loss_total", "mean_t", "mean_level")
            ))
            if use_wandb:
                wandb.log({"step": step, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            save_checkpoint(step, _unwrap(model), optimizer, config, save_dir, last_metrics)

    if is_main:
        save_checkpoint(num_steps, _unwrap(model), optimizer, config, save_dir, last_metrics)
        print("Training complete.")
    if pbar is not None:
        pbar.close()
    if use_wandb:
        wandb.finish()
    if dist is not None:
        dist.destroy_process_group()
if __name__ == "__main__":
    # Start training only when run as a script (python / torchrun), not on import.
    main()