#!/usr/bin/env python3
"""
train_block_diffusion.py – Block-form mask diffusion training (no ancestor states).

Uses the same SADModel block architecture and forward_vectorized training path
as train_sad.py, but the corruption process is binary:

  - level 0: clean
  - level 1: mask

No AncestorTable is created, and the loss is the binary MDLM/SUBS-style masked
token objective over corrupted positions only.

Usage:
  python scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml
  torchrun --nproc_per_node=8 scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml
  torchrun --nproc_per_node=8 scripts/train_block_diffusion.py \
      --config configs/block_diffusion_owt_b32.yaml \
      --resume outputs/block_diffusion/latest.pt
"""

import sys
import os
import argparse
import itertools
import math
import shutil
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]  # sad/

import torch
import torch.distributed as dist
import torch.nn as nn
import yaml

sys.path.insert(0, str(ROOT))

from src.utils import set_seed, count_parameters
from src.models.sad_model import SADModel
from src.diffusion.noisy_state import NoisyStateBuilder
from src.losses.sad_loss import SADLoss
from src.data import build_debug_dataloader, build_owt_dataloader

try:
    from tqdm import tqdm
    _has_tqdm = True
except ImportError:
    _has_tqdm = False


def _unwrap(model):
    """Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to SADModel."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


def parse_args():
    """Parse command-line arguments for the training script."""
    p = argparse.ArgumentParser()
    p.add_argument("--config", default="configs/block_diffusion_owt_b32.yaml")
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--num_steps", type=int, default=None,
                   help="Override training.num_steps in config")
    p.add_argument("--batch_size", type=int, default=None,
                   help="Override training.batch_size (per-GPU) in config")
    # Accept both spellings: newer launchers pass the hyphenated form.
    # argparse derives dest from the first long option, so it stays `local_rank`.
    p.add_argument("--local_rank", "--local-rank", type=int, default=0)
    return p.parse_args()


def load_config(path: str) -> dict:
    """Load the YAML config at *path* with ``yaml.safe_load``."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_tokenizer(config: dict):
    """
    Build the tokenizer for the configured dataset.

    For ``data.dataset == "debug"`` a minimal mock object is returned whose
    mask token is the last vocabulary id. Otherwise the local GPT-2 tokenizer
    under ``ROOT/tokenizers/gpt2`` is loaded, missing special tokens are
    filled in, and ``config["model"]["vocab_size"]`` (plus ``level_sizes[0]``
    when present) is updated in place to ``len(tokenizer)``.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    vocab_size = config["model"]["vocab_size"]

    if dataset == "debug":
        class MockTokenizer:
            def __init__(self, vocab_size, mask_token_id):
                self.vocab_size = vocab_size
                self.mask_token_id = mask_token_id
                self.pad_token_id = 0
                self.eos_token_id = 0
                self.model_max_length = config["model"]["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        return MockTokenizer(vocab_size, vocab_size - 1)

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    if tok.bos_token is None:
        tok.bos_token = tok.eos_token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})

    # Keep the model's vocabulary in sync with whatever tokens were added above.
    config["model"]["vocab_size"] = len(tok)
    if "level_sizes" in config["model"] and config["model"]["level_sizes"]:
        config["model"]["level_sizes"][0] = len(tok)
    return tok


def build_dataloaders(config: dict, tokenizer):
    """
    Build ``(train_loader, val_loader)`` for the configured dataset.

    Raises:
        ValueError: if ``data.dataset`` is neither ``"debug"`` nor
            ``"openwebtext"``.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        train_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=512,
            mask_token_id=tokenizer.mask_token_id,
        )
        val_loader = build_debug_dataloader(
            vocab_size=config["model"]["vocab_size"],
            seq_len=seq_len,
            batch_size=batch_size,
            num_samples=64,
            mask_token_id=tokenizer.mask_token_id,
        )
    elif dataset == "openwebtext":
        mode = data_cfg.get("mode", "subsample")
        train_loader = build_owt_dataloader(
            tokenizer,
            split="train[:-100000]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=data_cfg.get("num_workers", 4),
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_train_samples", None),
            mode=mode,
        )
        val_loader = build_owt_dataloader(
            tokenizer,
            split="train[-100000:]",
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=2,
            cache_dir=data_cfg.get("cache_dir", None),
            max_samples=data_cfg.get("max_val_samples", 100000),
            mode=mode,
            shard_across_ranks=False,
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    return train_loader, val_loader


def build_optimizer(config: dict, model: nn.Module):
    """
    Build an AdamW optimizer over all of *model*'s parameters.

    Hyper-parameters come from ``config["training"]`` (``lr`` is required;
    ``weight_decay``, ``adam_betas`` and ``adam_eps`` have defaults).
    """
    train_cfg = config["training"]
    betas = tuple(train_cfg.get("adam_betas", (0.9, 0.99)))
    params = list(model.parameters())
    # The fused kernel is only requested when every parameter lives on CUDA so
    # that the CPU fallback path in main() does not fail on PyTorch builds
    # without a CPU fused implementation.
    use_fused = bool(params) and all(p.is_cuda for p in params)
    return torch.optim.AdamW(
        params,
        lr=train_cfg["lr"],
        weight_decay=train_cfg.get("weight_decay", 0.02),
        betas=betas,
        eps=train_cfg.get("adam_eps", 1e-9),
        fused=use_fused,
    )


def get_lr(step: int, config: dict) -> float:
    """
    Return the learning rate for *step*: linear warmup, then cosine decay.

    Warmup rises linearly from 0 to ``training.lr`` over ``warmup_steps``
    (default ``min(2000, num_steps // 100)``); afterwards the rate follows a
    half cosine from ``lr`` down to ``lr_min`` (default ``0.1 * lr``) at
    ``num_steps``.
    """
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_min = train_cfg.get("lr_min", train_cfg["lr"] * 0.1)
    lr_max = train_cfg["lr"]

    if step < warmup:
        return lr_max * step / max(warmup, 1)

    progress = (step - warmup) / max(num_steps - warmup, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


def build_mask_noisy_embeddings(
    input_ids: torch.Tensor,
    levels: torch.Tensor,
    leaf_embeddings: torch.Tensor,
    mask_embedding: torch.Tensor,
):
    """
    Binary corruption: level 0 keeps the leaf embedding, level 1 uses [MASK].

    Returns:
        noisy_embs: ``leaf_embeddings[input_ids]`` with masked positions
            replaced by *mask_embedding* (cast to the embeddings' dtype).
        corrupt_mask: ``levels.bool()``, True where the token was masked.
    """
    clean_embs = leaf_embeddings[input_ids]
    corrupt_mask = levels.bool()
    # torch.where produces the same values as cloning and writing through a
    # boolean index, without the host sync that `mask.any()` would force.
    noisy_embs = torch.where(
        corrupt_mask.unsqueeze(-1),
        mask_embedding.to(clean_embs.dtype),
        clean_embs,
    )
    return noisy_embs, corrupt_mask


def sample_binary_levels(
    noisy_builder: NoisyStateBuilder,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    t_eps: float,
):
    """
    Sample one t per sequence, then mask each token i.i.d. with probability t.

    Returns:
        t: [B] float in [t_eps, 1 - t_eps]
        levels: [B, S] int64 with values in {0=clean, 1=mask}
    """
    t = noisy_builder.sample_t(batch_size, device=device, eps=t_eps)
    levels = torch.bernoulli(
        t[:, None].expand(batch_size, seq_len)
    ).to(dtype=torch.long)
    return t, levels


def block_mask_step(
    batch: dict,
    model,
    loss_fn: SADLoss,
    noisy_builder: NoisyStateBuilder,
    tokenizer,
    dtype,
    t_eps: float,
) -> tuple:
    """
    One block-mask diffusion training step using forward_vectorized.

    Reads ``batch["input_ids"]`` and ``batch["attention_mask"]``, samples a
    binary corruption, runs the model and the loss, and returns
    ``(loss, metrics)``. ``metrics`` is the dict returned by *loss_fn* with
    ``mean_level``, ``mean_t`` and ``logits_absmax`` added.
    """
    device = batch["input_ids"].device
    input_ids = batch["input_ids"]            # [B, L]
    attention_mask = batch["attention_mask"]  # [B, L]
    batch_size, seq_len = input_ids.shape

    autocast_device = "cuda" if device.type == "cuda" else "cpu"
    raw_model = _unwrap(model)

    # float32 is not an autocast target dtype; disable autocast explicitly
    # rather than relying on the warn-and-disable fallback.
    with torch.autocast(
        device_type=autocast_device,
        dtype=dtype,
        enabled=dtype != torch.float32,
    ):
        leaf_emb = raw_model.get_leaf_embeddings()   # [V, d]
        mask_emb = leaf_emb[tokenizer.mask_token_id]  # [d]
        clean_embs = leaf_emb[input_ids]              # [B, L, d]

        t, levels = sample_binary_levels(
            noisy_builder, batch_size, seq_len, device=device, t_eps=t_eps,
        )
        noisy_embs, corrupt_mask = build_mask_noisy_embeddings(
            input_ids, levels, leaf_emb, mask_emb,
        )

        leaf_logits = model(
            noisy_embs=noisy_embs,
            clean_embs=clean_embs,
            attention_mask=attention_mask,
        )
        loss, metrics = loss_fn(
            leaf_logits=leaf_logits,
            input_ids=input_ids,
            levels=levels,
            attention_mask=attention_mask,
            t=t,
            corrupt_mask=corrupt_mask,
        )

    metrics["mean_level"] = levels.float().mean().detach()
    metrics["mean_t"] = t.float().mean().detach()
    metrics["logits_absmax"] = leaf_logits.detach().abs().max()
    return loss, metrics


def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """
    Write ``ckpt_{step}.pt`` and refresh ``latest.pt`` inside *save_dir*.

    The checkpoint holds the step, model and optimizer state dicts, the
    config, and *metrics* with tensor values converted via ``.item()``.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v
                    for k, v in metrics.items()},
    }
    ckpt_path = save_dir / f"ckpt_{step}.pt"
    latest_path = save_dir / "latest.pt"
    torch.save(ckpt, ckpt_path)
    # Serialise once, then publish latest.pt via copy + atomic rename so an
    # interrupted write can never leave a truncated resume target behind.
    tmp_path = save_dir / "latest.pt.tmp"
    shutil.copyfile(ckpt_path, tmp_path)
    os.replace(tmp_path, latest_path)
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")


@torch.no_grad()
def evaluate(model, loss_fn, noisy_builder, tokenizer, dtype, val_loader,
             device, t_eps: float, num_batches: int = 50) -> dict:
    """
    Average the metrics of ``block_mask_step`` over up to *num_batches* batches.

    The model is put in eval mode for the duration and switched back to train
    mode afterwards, even if a batch raises. Returns a dict of float means
    (empty if the loader yields nothing).
    """
    model.eval()
    total_metrics = {}
    count = 0
    try:
        for raw_batch in itertools.islice(val_loader, num_batches):
            # Move only the tensors block_mask_step reads, as the training loop does.
            batch = {
                "input_ids": raw_batch["input_ids"].to(device, non_blocking=True),
                "attention_mask": raw_batch["attention_mask"].to(device, non_blocking=True),
            }
            _, metrics = block_mask_step(
                batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
            )
            for k, v in metrics.items():
                val = v.item() if hasattr(v, "item") else float(v)
                total_metrics[k] = total_metrics.get(k, 0.0) + val
            count += 1
    finally:
        model.train()
    return {k: v / max(count, 1) for k, v in total_metrics.items()}


def fmt_metric(v) -> str:
    """Format a scalar (tensor or number) with four decimal places."""
    v = v.item() if hasattr(v, "item") else float(v)
    return f"{v:.4f}"


def main():
    """Entry point: build everything from the config and run the training loop."""
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1

    if distributed:
        dist.init_process_group("nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        rank = dist.get_rank()
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        rank = 0
    else:
        device = torch.device("cpu")
        rank = 0

    # Use the *global* rank: on multi-node jobs LOCAL_RANK is 0 once per node,
    # which would make several processes log and write checkpoints.
    is_main = rank == 0

    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    # Offset by global rank so no two processes share a noise stream.
    set_seed(train_cfg.get("seed", 42) + rank)

    # NOTE(review): "fp16" is accepted here but no GradScaler is used in the
    # loop below; confirm fp16 training is actually intended to be supported.
    dtype_str = train_cfg.get("dtype", "bf16")
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)
    # Explicit raise instead of assert so the check survives `python -O`.
    if tokenizer.pad_token_id != tokenizer.eos_token_id:
        raise ValueError(
            f"forward_vectorized flex path assumes pad_token_id == eos_token_id, "
            f"got pad={tokenizer.pad_token_id}, eos={tokenizer.eos_token_id}."
        )

    model_cfg = config["model"]
    model = SADModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        cond_dim=model_cfg["cond_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        block_size=model_cfg.get("block_size", 16),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 1),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    ).to(device)

    loss_cfg = config.get("loss", {})
    loss_fn = SADLoss(
        vocab_size=model_cfg["vocab_size"],
        lambda_ancestor=0.0,
        ancestor_table=None,
        mask_only=loss_cfg.get("mask_only", True),
        use_mdlm=loss_cfg.get("use_mdlm", True),
        mdlm_masked_sum_over_all_tokens=True,
    ).to(device)

    noisy_builder = NoisyStateBuilder(
        vocab_size=model_cfg["vocab_size"],
        mask_token_id=tokenizer.mask_token_id,
    )

    if distributed:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = train_cfg.get("compile", "default")
    # PyYAML parses unquoted `off`/`on` as booleans, so normalise those forms
    # instead of handing a bool to torch.compile(mode=...).
    if compile_mode is True:
        compile_mode = "default"
    elif compile_mode is False or compile_mode is None:
        compile_mode = "off"
    if compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model)

    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        # Checkpoints store the unwrapped model's state dict (see save calls below).
        _unwrap(model).load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/block_diffusion"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            # Bound once here; every later use is guarded by `use_wandb`.
            import wandb
            wandb.init(project=log_cfg.get("project", "block_diffusion"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    t_eps = float(train_cfg.get("t_eps", 1e-3))

    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="Block-mask diffusion training",
        )
    else:
        pbar = None

    t0 = time.time()
    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        try:
            full_batch = next(train_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = block_mask_step(
            batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
        )

        # Agree across ranks on whether this step is usable, so that either
        # every rank runs backward() or none does.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if distributed:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)

        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
            if use_wandb:
                wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        optimizer.step()

        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                leaf=fmt_metric(metrics["loss_leaf"]),
                total=fmt_metric(metrics["loss_total"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"leaf={fmt_metric(metrics['loss_leaf'])} | "
                f"total={fmt_metric(metrics['loss_total'])} | "
                f"t={fmt_metric(metrics['mean_t'])} | "
                f"mask={fmt_metric(metrics['mean_level'])} | "
                f"lr={lr:.2e} | "
                f"t_wall={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step,
                    "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v
                       for k, v in metrics.items()},
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # NOTE(review): only the main rank runs this forward pass through
            # the (possibly DDP-wrapped) model while the other ranks continue;
            # confirm the wrapper issues no collectives in eval/no_grad mode.
            val_metrics = evaluate(
                model, loss_fn, noisy_builder, tokenizer, dtype,
                val_loader, device, t_eps,
            )
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_leaf", "loss_total", "mean_t", "mean_level")
            ))
            if use_wandb:
                wandb.log({"step": step, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            save_checkpoint(step, _unwrap(model), optimizer, config, save_dir, last_metrics)

    if is_main:
        save_checkpoint(num_steps, _unwrap(model), optimizer, config, save_dir, last_metrics)
        print("Training complete.")

    if pbar is not None:
        pbar.close()

    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()