| |
| """ |
| train_block_diffusion.py – Block-form mask diffusion training (no ancestor states). |
| |
| Uses the same SADModel block architecture and forward_vectorized training path as |
| train_sad.py, but the corruption process is binary: |
| - level 0: clean |
| - level 1: mask |
| |
| No AncestorTable is created, and the loss is the binary MDLM/SUBS-style masked |
| token objective over corrupted positions only. |
| |
| Usage: |
| python scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml |
| torchrun --nproc_per_node=8 scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml |
| torchrun --nproc_per_node=8 scripts/train_block_diffusion.py \ |
| --config configs/block_diffusion_owt_b32.yaml \ |
| --resume outputs/block_diffusion/latest.pt |
| """ |
|
|
| import sys |
| import os |
| import argparse |
| import math |
| import time |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
|
|
| import torch |
| import torch.nn as nn |
| import yaml |
|
|
| sys.path.insert(0, str(ROOT)) |
|
|
| from src.utils import set_seed, count_parameters |
| from src.models.sad_model import SADModel |
| from src.diffusion.noisy_state import NoisyStateBuilder |
| from src.losses.sad_loss import SADLoss |
| from src.data import build_debug_dataloader, build_owt_dataloader |
|
|
| try: |
| from tqdm import tqdm |
| _has_tqdm = True |
| except ImportError: |
| _has_tqdm = False |
|
|
|
|
| def _unwrap(model): |
| """Peel DDP (.module) and torch.compile (._orig_mod) wrappers down to SADModel.""" |
| while True: |
| if hasattr(model, "_orig_mod"): |
| model = model._orig_mod |
| elif hasattr(model, "module"): |
| model = model.module |
| else: |
| return model |
|
|
|
|
def parse_args():
    """Parse command-line flags for the block-diffusion trainer.

    ``--num_steps`` and ``--batch_size`` override the matching ``training.*``
    config values when given; ``--local_rank`` is only a fallback for the
    ``LOCAL_RANK`` environment variable.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--config", default="configs/block_diffusion_owt_b32.yaml")
    add("--resume", type=str, default=None)
    add("--num_steps", default=None, type=int,
        help="Override training.num_steps in config")
    add("--batch_size", default=None, type=int,
        help="Override training.batch_size (per-GPU) in config")
    add("--local_rank", default=0, type=int)
    return parser.parse_args()
|
|
|
|
def load_config(path: str) -> dict:
    """Load a YAML training config and return it as a dict.

    The file is parsed with ``yaml.safe_load``. Note that PyYAML follows YAML 1.1,
    so exponent literals without a decimal point (e.g. ``3e-4``) load as strings
    and unquoted ``on``/``off`` load as booleans; callers cast where needed.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping
            (otherwise the failure would surface later as an obscure KeyError).
    """
    # Explicit encoding: config files should not depend on the host locale.
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {path!r} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}."
        )
    return config
|
|
|
|
def build_tokenizer(config: dict):
    """Return the tokenizer selected by ``config["data"]["dataset"]``.

    * ``"debug"`` (the default): a minimal mock covering ids ``[0, vocab_size)``;
      the highest id is [MASK], and id 0 serves as both PAD and EOS.
    * anything else: the GPT-2 tokenizer loaded from ``<ROOT>/tokenizers/gpt2``
      with ``local_files_only=True``. Missing EOS/BOS/PAD/MASK tokens are filled
      in, and ``config["model"]["vocab_size"]`` (plus ``level_sizes[0]`` when
      non-empty) is updated in place to ``len(tok)``.
    """
    dataset = config.get("data", {}).get("dataset", "debug")
    model_cfg = config["model"]
    vocab_size = model_cfg["vocab_size"]

    if dataset == "debug":
        class MockTokenizer:
            # Stand-in exposing only the tokenizer attributes this script reads.
            def __init__(self, vocab_size, mask_token_id):
                self.vocab_size = vocab_size
                self.mask_token_id = mask_token_id
                self.pad_token_id = self.eos_token_id = 0
                self.model_max_length = model_cfg["max_seq_len"]

            def __len__(self):
                return self.vocab_size

        # The highest id is reserved for [MASK].
        return MockTokenizer(vocab_size, mask_token_id=vocab_size - 1)

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        ROOT / "tokenizers" / "gpt2",
        local_files_only=True,
    )
    if tok.eos_token is None:
        tok.add_special_tokens({"eos_token": "<|endoftext|>"})
    # BOS and PAD fall back to the EOS token when the tokenizer lacks them.
    for special in ("bos_token", "pad_token"):
        if getattr(tok, special) is None:
            setattr(tok, special, tok.eos_token)
    if tok.mask_token_id is None:
        tok.add_special_tokens({"mask_token": "[MASK]"})

    # Keep the model config in sync with any special tokens added above.
    full_size = len(tok)
    model_cfg["vocab_size"] = full_size
    if model_cfg.get("level_sizes"):
        model_cfg["level_sizes"][0] = full_size
    return tok
|
|
|
|
def build_dataloaders(config: dict, tokenizer):
    """Create the ``(train_loader, val_loader)`` pair for ``config["data"]["dataset"]``.

    * ``"debug"``: synthetic loaders from ``build_debug_dataloader`` with 512
      training and 64 validation samples.
    * ``"openwebtext"``: ``build_owt_dataloader`` with split strings
      ``train[:-100000]`` (train) and ``train[-100000:]`` (validation); the
      validation loader uses 2 workers and ``shard_across_ranks=False``.

    Raises:
        ValueError: for any other dataset name.
    """
    data_cfg = config.get("data", {})
    dataset = data_cfg.get("dataset", "debug")
    seq_len = data_cfg.get("seq_len", 512)
    batch_size = config["training"]["batch_size"]

    if dataset == "debug":
        def make_debug_loader(num_samples):
            return build_debug_dataloader(
                vocab_size=config["model"]["vocab_size"],
                seq_len=seq_len,
                batch_size=batch_size,
                num_samples=num_samples,
                mask_token_id=tokenizer.mask_token_id,
            )

        return make_debug_loader(512), make_debug_loader(64)

    if dataset != "openwebtext":
        raise ValueError(f"Unknown dataset: {dataset}")

    # Arguments identical for the train and validation loaders.
    shared = dict(
        seq_len=seq_len,
        batch_size=batch_size,
        cache_dir=data_cfg.get("cache_dir", None),
        mode=data_cfg.get("mode", "subsample"),
    )
    train_loader = build_owt_dataloader(
        tokenizer,
        split="train[:-100000]",
        num_workers=data_cfg.get("num_workers", 4),
        max_samples=data_cfg.get("max_train_samples", None),
        **shared,
    )
    val_loader = build_owt_dataloader(
        tokenizer,
        split="train[-100000:]",
        num_workers=2,
        max_samples=data_cfg.get("max_val_samples", 100000),
        shard_across_ranks=False,
        **shared,
    )
    return train_loader, val_loader
|
|
|
|
def build_optimizer(config: dict, model: nn.Module):
    """Create an AdamW optimizer over all of ``model``'s parameters.

    Reads from ``config["training"]``: ``lr`` (required), ``weight_decay``
    (default 0.02), ``adam_betas`` (default (0.9, 0.99)), ``adam_eps``
    (default 1e-9) and the optional ``fused_adam`` flag.

    Numeric values are cast to float because PyYAML parses exponent literals
    without a decimal point (e.g. ``lr: 3e-4``) as strings, which AdamW rejects.

    The fused AdamW kernel is used when ``fused_adam`` is truthy; if unset, it
    is enabled only when every parameter lives on CUDA, so CPU/debug runs do
    not depend on fused-kernel support for their device.
    """
    train_cfg = config["training"]
    params = list(model.parameters())
    betas = tuple(float(b) for b in train_cfg.get("adam_betas", (0.9, 0.99)))

    use_fused = train_cfg.get("fused_adam")
    if use_fused is None:
        use_fused = bool(params) and all(p.is_cuda for p in params)

    return torch.optim.AdamW(
        params,
        lr=float(train_cfg["lr"]),
        weight_decay=float(train_cfg.get("weight_decay", 0.02)),
        betas=betas,
        eps=float(train_cfg.get("adam_eps", 1e-9)),
        # None (not False) lets PyTorch pick its default foreach implementation;
        # an explicit False would force the slow single-tensor loop.
        fused=True if use_fused else None,
    )
|
|
|
|
def get_lr(step: int, config: dict) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    Config keys (under ``training``): ``num_steps``, ``lr``, optional
    ``warmup_steps`` (default ``min(2000, num_steps // 100)``) and ``lr_min``
    (default ``0.1 * lr``). The LR rises linearly from 0 to ``lr`` over the
    warmup and then follows a half cosine down to ``lr_min`` at ``num_steps``.
    Steps past ``num_steps`` stay at ``lr_min`` instead of climbing back up.

    ``lr``/``lr_min`` are cast to float because PyYAML loads exponent literals
    without a decimal point (e.g. ``3e-4``) as strings.
    """
    train_cfg = config["training"]
    num_steps = train_cfg["num_steps"]
    warmup = train_cfg.get("warmup_steps", min(2000, num_steps // 100))
    lr_max = float(train_cfg["lr"])
    lr_min = float(train_cfg.get("lr_min", lr_max * 0.1))
    if step < warmup:
        return lr_max * step / max(warmup, 1)
    # Clamp so the cosine does not wrap around past the end of the schedule.
    progress = min((step - warmup) / max(num_steps - warmup, 1), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
|
|
|
|
def build_mask_noisy_embeddings(
    input_ids: torch.Tensor,
    levels: torch.Tensor,
    leaf_embeddings: torch.Tensor,
    mask_embedding: torch.Tensor,
):
    """Apply binary corruption: level 0 keeps the token embedding, level 1 uses [MASK].

    Args:
        input_ids: integer token ids used to index ``leaf_embeddings``.
        levels: same shape as ``input_ids``; nonzero marks a masked position.
        leaf_embeddings: ``[V, D]`` embedding table.
        mask_embedding: ``[D]`` embedding substituted at masked positions (cast
            to the table's dtype).

    Returns:
        ``(noisy_embs, corrupt_mask)``: ``noisy_embs`` has shape
        ``input_ids.shape + (D,)``; ``corrupt_mask`` is ``levels.bool()``.
        Gradients flow to ``leaf_embeddings`` at clean positions and to
        ``mask_embedding`` at masked ones.
    """
    clean = leaf_embeddings[input_ids]  # advanced indexing already returns a copy
    corrupt_mask = levels.bool()
    # torch.where avoids an extra clone and the host sync that a data-dependent
    # `if mask.any()` check would force every step.
    noisy_embs = torch.where(
        corrupt_mask.unsqueeze(-1),
        mask_embedding.to(clean.dtype),
        clean,
    )
    return noisy_embs, corrupt_mask
|
|
|
|
def sample_binary_levels(
    noisy_builder: NoisyStateBuilder,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    t_eps: float,
):
    """Draw one mask rate per sequence and i.i.d. binary levels per token.

    ``t`` comes from ``noisy_builder.sample_t(batch_size, device=device,
    eps=t_eps)``; every position in row ``b`` is then masked independently with
    probability ``t[b]``.

    Returns:
        t: ``[B]`` tensor from ``sample_t`` (presumably in ``[t_eps, 1 - t_eps]``;
            confirm in ``NoisyStateBuilder.sample_t``).
        levels: ``[B, S]`` int64 tensor with 0 = clean and 1 = mask.
    """
    t = noisy_builder.sample_t(batch_size, device=device, eps=t_eps)
    per_token_prob = t.unsqueeze(1).expand(batch_size, seq_len)
    levels = torch.bernoulli(per_token_prob).long()
    return t, levels
|
|
|
|
def block_mask_step(
    batch: dict,
    model,
    loss_fn: SADLoss,
    noisy_builder: NoisyStateBuilder,
    tokenizer,
    dtype,
    t_eps: float,
) -> tuple:
    """Run one block-mask diffusion forward pass and compute the loss.

    The function:

    1. Looks up clean embeddings from the model's leaf table.
    2. Samples a per-sequence mask rate ``t`` and binary levels.
    3. Replaces masked positions with the [MASK] embedding.
    4. Runs ``model`` via its ``forward_vectorized`` path.
    5. Scores the logits with ``loss_fn``.

    Used for both training and (under ``torch.no_grad``) evaluation.

    Args:
        batch: dict with ``input_ids`` and ``attention_mask`` tensors already on
            the target device (assumed ``[B, S]`` — TODO confirm in the loaders).
        model: the SADModel, possibly wrapped by DDP and/or torch.compile.
        loss_fn: SADLoss instance.
        noisy_builder: supplies ``sample_t`` for the mask rate.
        tokenizer: supplies ``mask_token_id``.
        dtype: autocast dtype.
        t_eps: epsilon forwarded to ``noisy_builder.sample_t``.

    Returns:
        ``(loss, metrics)``: ``metrics`` is the dict from ``loss_fn`` extended
        with detached ``mean_level``, ``mean_t`` and ``logits_absmax``.
    """
    device = batch["input_ids"].device
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    batch_size, seq_len = input_ids.shape

    autocast_device = "cuda" if device.type == "cuda" else "cpu"
    # Embedding lookups go through the bare SADModel; only the forward call
    # below goes through the DDP / torch.compile wrappers.
    raw_model = _unwrap(model)
    with torch.autocast(device_type=autocast_device, dtype=dtype):
        leaf_emb = raw_model.get_leaf_embeddings()
        mask_emb = leaf_emb[tokenizer.mask_token_id]
        clean_embs = leaf_emb[input_ids]

        t, levels = sample_binary_levels(
            noisy_builder, batch_size, seq_len, device=device, t_eps=t_eps,
        )
        noisy_embs, corrupt_mask = build_mask_noisy_embeddings(
            input_ids, levels, leaf_emb, mask_emb,
        )

        leaf_logits = model(
            noisy_embs=noisy_embs,
            clean_embs=clean_embs,
            attention_mask=attention_mask,
        )

        loss, metrics = loss_fn(
            leaf_logits=leaf_logits,
            input_ids=input_ids,
            levels=levels,
            attention_mask=attention_mask,
            t=t,
            corrupt_mask=corrupt_mask,
        )

    metrics["mean_level"] = levels.float().mean().detach()
    metrics["mean_t"] = t.float().mean().detach()
    # max|x| == max(max(x), -min(x)). Two reductions avoid materialising a
    # logits-sized abs() copy while the autograd graph is still alive.
    logits = leaf_logits.detach()
    metrics["logits_absmax"] = torch.maximum(logits.amax(), logits.amin().neg())
    return loss, metrics
|
|
|
|
def save_checkpoint(step, model, optimizer, config, save_dir: Path, metrics: dict):
    """Save training state to ``save_dir/ckpt_{step}.pt`` and refresh ``latest.pt``.

    The checkpoint dict holds ``step``, model and optimizer state dicts, the
    config, and ``metrics`` converted to Python scalars (via ``.item()`` where
    available). Each file is written to a ``.tmp`` sibling and moved into place
    with ``os.replace``. An interrupted save therefore never leaves a truncated
    checkpoint, notably ``latest.pt``, which ``--resume`` loads. ``latest.pt``
    is a byte copy of the step file rather than a second serialisation.
    """
    import shutil

    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "metrics": {k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
    }
    step_path = save_dir / f"ckpt_{step}.pt"
    latest_path = save_dir / "latest.pt"

    step_tmp = step_path.with_name(step_path.name + ".tmp")
    torch.save(ckpt, step_tmp)
    os.replace(step_tmp, step_path)

    latest_tmp = latest_path.with_name(latest_path.name + ".tmp")
    shutil.copyfile(step_path, latest_tmp)
    os.replace(latest_tmp, latest_path)
    print(f" Saved checkpoint: {save_dir}/ckpt_{step}.pt")
|
|
|
|
@torch.no_grad()
def evaluate(model, loss_fn, noisy_builder, tokenizer, dtype, val_loader, device,
             t_eps: float, num_batches: int = 50) -> dict:
    """Average ``block_mask_step`` metrics over the first ``num_batches`` validation batches.

    Runs without gradients and with ``model`` in eval mode. The model's previous
    train/eval mode is restored afterwards, even if a batch raises. Mask levels
    are freshly sampled on each call, so results carry sampling noise.

    Returns:
        Mapping from metric name to its mean (a Python float) over the evaluated
        batches; empty if ``val_loader`` yields nothing.
    """
    import itertools

    was_training = model.training
    model.eval()
    totals: dict = {}
    count = 0
    try:
        # islice stops exactly at num_batches without fetching an extra batch.
        for batch in itertools.islice(val_loader, max(num_batches, 0)):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            _, metrics = block_mask_step(
                batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
            )
            for name, value in metrics.items():
                scalar = value.item() if hasattr(value, "item") else float(value)
                totals[name] = totals.get(name, 0.0) + scalar
            count += 1
    finally:
        model.train(was_training)
    return {name: total / max(count, 1) for name, total in totals.items()}
|
|
|
|
def fmt_metric(v) -> str:
    """Render a scalar (Python number, numeric string or 0-d tensor) with 4 decimals."""
    scalar = v.item() if hasattr(v, "item") else float(v)
    return format(scalar, ".4f")
|
|
|
|
def main():
    """Train a block-mask diffusion model from a YAML config.

    Supports single-process runs and ``torchrun`` (DDP) launches. Every rank
    trains; the global rank-0 process alone handles console logging, periodic
    validation, W&B logging and checkpointing.
    """
    args = parse_args()
    config = load_config(args.config)
    if args.num_steps is not None:
        config["training"]["num_steps"] = args.num_steps
    if args.batch_size is not None:
        config["training"]["batch_size"] = args.batch_size

    # LOCAL_RANK selects the GPU on this node; RANK is the global process index.
    # Choosing the main process (and the seed offset) from LOCAL_RANK would make
    # local rank 0 on *every* node a "main" process in multi-node runs: duplicate
    # checkpoint writes and W&B runs, plus colliding seeds across nodes.
    # RANK falls back to LOCAL_RANK when unset (single-node / plain launches).
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_main = rank == 0

    if world_size > 1:
        import torch.distributed as dist
        dist.init_process_group("nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if is_main:
        print(f"Device: {device} world_size: {world_size}")

    train_cfg = config["training"]
    set_seed(train_cfg.get("seed", 42) + rank)

    dtype_str = train_cfg.get("dtype", "bf16")
    # NOTE(review): "fp16" autocasts without a GradScaler, so small gradients may
    # underflow; prefer bf16/fp32 unless loss scaling is added.
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[dtype_str]

    tokenizer = build_tokenizer(config)
    # Explicit raise (not assert) so the check survives `python -O`.
    if tokenizer.pad_token_id != tokenizer.eos_token_id:
        raise ValueError(
            f"forward_vectorized flex path assumes pad_token_id == eos_token_id, "
            f"got pad={tokenizer.pad_token_id}, eos={tokenizer.eos_token_id}."
        )

    model_cfg = config["model"]
    model = SADModel(
        vocab_size=model_cfg["vocab_size"],
        hidden_size=model_cfg["hidden_size"],
        n_blocks=model_cfg["n_blocks"],
        n_heads=model_cfg["n_heads"],
        cond_dim=model_cfg["cond_dim"],
        max_seq_len=model_cfg["max_seq_len"],
        block_size=model_cfg.get("block_size", 16),
        dropout=model_cfg.get("dropout", 0.0),
        num_levels=model_cfg.get("num_levels", 1),
        level_sizes=model_cfg.get("level_sizes"),
        tie_weights=model_cfg.get("tie_weights", False),
    ).to(device)

    # Binary masking only: no ancestor table and no ancestor loss term.
    loss_cfg = config.get("loss", {})
    loss_fn = SADLoss(
        vocab_size=model_cfg["vocab_size"],
        lambda_ancestor=0.0,
        ancestor_table=None,
        mask_only=loss_cfg.get("mask_only", True),
        use_mdlm=loss_cfg.get("use_mdlm", True),
        mdlm_masked_sum_over_all_tokens=True,
    ).to(device)

    noisy_builder = NoisyStateBuilder(
        vocab_size=model_cfg["vocab_size"],
        mask_token_id=tokenizer.mask_token_id,
    )

    if world_size > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(
            model,
            device_ids=[local_rank],
            static_graph=True,
            gradient_as_bucket_view=True,
        )

    compile_mode = train_cfg.get("compile", "default")
    # PyYAML follows YAML 1.1, where unquoted on/off/yes/no/true/false load as
    # booleans: `compile: off` arrives as False and `compile: true` as True.
    # Treat False like "off" and True like "default" rather than handing a bool
    # to torch.compile as its mode.
    if compile_mode is True:
        compile_mode = "default"
    if compile_mode is not False and compile_mode != "off":
        if is_main:
            print(f"[compile] torch.compile(mode={compile_mode!r}) — first step will be slow")
        model = torch.compile(model, mode=compile_mode, dynamic=False)

    optimizer = build_optimizer(config, model)

    if is_main:
        print(f"Model params: {count_parameters(model):,}")

    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        raw_model = _unwrap(model)
        raw_model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # Checkpoints are written after the optimizer step of `step`.
        start_step = ckpt["step"] + 1
        if is_main:
            print(f"Resumed from step {start_step}")

    train_loader, val_loader = build_dataloaders(config, tokenizer)
    train_iter = iter(train_loader)

    log_cfg = config.get("logging", {})
    save_dir = Path(log_cfg.get("save_dir", "outputs/block_diffusion"))
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    use_wandb = is_main and log_cfg.get("use_wandb", False)
    if use_wandb:
        try:
            import wandb
            wandb.init(project=log_cfg.get("project", "block_diffusion"), config=config)
        except ImportError:
            use_wandb = False

    model.train()
    num_steps = train_cfg["num_steps"]
    grad_clip = train_cfg.get("grad_clip", 1.0)
    log_interval = train_cfg.get("log_interval", 100)
    eval_interval = train_cfg.get("eval_interval", 5000)
    save_interval = train_cfg.get("save_interval", 10000)
    t_eps = float(train_cfg.get("t_eps", 1e-3))

    last_metrics: dict = {}
    nan_skips = 0

    if is_main and _has_tqdm:
        pbar = tqdm(
            total=num_steps,
            initial=start_step,
            dynamic_ncols=True,
            desc="Block-mask diffusion training",
        )
    else:
        pbar = None

    t0 = time.time()

    for step in range(start_step, num_steps):
        lr = get_lr(step, config)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Cycle the training loader indefinitely.
        try:
            full_batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            full_batch = next(train_iter)

        batch = {
            "input_ids": full_batch["input_ids"].to(device, non_blocking=True),
            "attention_mask": full_batch["attention_mask"].to(device, non_blocking=True),
        }

        optimizer.zero_grad()
        loss, metrics = block_mask_step(
            batch, model, loss_fn, noisy_builder, tokenizer, dtype, t_eps,
        )

        # All ranks agree (MIN-reduce) on skipping a step whose loss is
        # non-finite on any rank, so DDP gradient collectives stay matched.
        finite_flag = torch.ones(1, device=device, dtype=torch.int32)
        if not torch.isfinite(loss):
            finite_flag.zero_()
        if world_size > 1:
            dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
        if finite_flag.item() == 0:
            nan_skips += 1
            if is_main:
                print(f"[WARN] step={step} skipped: non-finite loss "
                      f"(total skips={nan_skips})")
                if use_wandb:
                    wandb.log({"step": step, "nan_skips": nan_skips})
            optimizer.zero_grad(set_to_none=True)
            if pbar is not None:
                pbar.update(1)
            continue

        loss.backward()

        if grad_clip > 0:
            nn.utils.clip_grad_norm_(list(model.parameters()), grad_clip)
        optimizer.step()
        last_metrics = metrics

        if pbar is not None:
            pbar.set_postfix(
                leaf=fmt_metric(metrics["loss_leaf"]),
                total=fmt_metric(metrics["loss_total"]),
                lr=f"{lr:.1e}",
            )
            pbar.update(1)

        if is_main and step % log_interval == 0:
            elapsed = time.time() - t0
            print(
                f"step={step:6d} | "
                f"leaf={fmt_metric(metrics['loss_leaf'])} | "
                f"total={fmt_metric(metrics['loss_total'])} | "
                f"t={fmt_metric(metrics['mean_t'])} | "
                f"mask={fmt_metric(metrics['mean_level'])} | "
                f"lr={lr:.2e} | "
                f"t_wall={elapsed:.1f}s"
            )
            if use_wandb:
                wandb.log({
                    "step": step,
                    "lr": lr,
                    **{k: v.item() if hasattr(v, "item") else v for k, v in metrics.items()},
                })
            t0 = time.time()

        if is_main and step % eval_interval == 0 and step > 0:
            # NOTE(review): only rank 0 evaluates, running forward passes through
            # the (possibly DDP-wrapped) model while other ranks keep training;
            # confirm this stays collective-safe (e.g. DDP buffer broadcasts).
            val_metrics = evaluate(
                model, loss_fn, noisy_builder, tokenizer, dtype, val_loader, device, t_eps,
            )
            print(" VAL | " + " | ".join(
                f"{k}={v:.4f}" for k, v in val_metrics.items()
                if k in ("loss_leaf", "loss_total", "mean_t", "mean_level")
            ))
            if use_wandb:
                wandb.log({"step": step, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if is_main and step % save_interval == 0 and step > 0:
            raw_model = _unwrap(model)
            save_checkpoint(step, raw_model, optimizer, config, save_dir, last_metrics)

    if is_main:
        raw_model = _unwrap(model)
        save_checkpoint(num_steps, raw_model, optimizer, config, save_dir, last_metrics)
        print("Training complete.")
        if pbar is not None:
            pbar.close()
        if use_wandb:
            wandb.finish()

    if world_size > 1:
        dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script (see the module docstring
    # for single-process and torchrun launch commands).
    main()
|
|