#!/usr/bin/env python3 from __future__ import annotations import argparse import inspect import json import math import os import random import time from contextlib import nullcontext from pathlib import Path import torch import torch.distributed as dist import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler from flowtext_lab.bridges import make_dirichlet_bridge_batch, make_prob_bridge_batch try: from flowtext_lab.bridges import make_gaussian_bridge_batch except ImportError: # Optional legacy bridge; most current runs use dirichlet. make_gaussian_bridge_batch = None from flowtext_lab.data import ( CachedWrappedTextSequenceDataset, EosPadCollator, FixedPadCollator, RecordPadTruncateTextSequenceDataset, ShuffledWrappedStreamingTextSequenceDataset, StreamingTextSequenceDataset, TextSequenceDataset, WrappedStreamingTextSequenceDataset, ) from flowtext_lab.decode import fill_blank_init, soft_residual_decode from flowtext_lab.metrics import ( masked_acc, masked_ce, masked_meanflow_l2, masked_soft_ce, masked_topk_acc, summarize_batch, summarize_t_bin_acc, ) from flowtext_lab.model import EndpointPredictor from flowtext_lab.text_detokenization import infer_detokenizer_name from flowtext_lab.tokenization import BpeTextTokenizer def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--data_path", required=True) p.add_argument("--text_column", default=None) p.add_argument("--txt_record_mode", choices=["auto", "line", "eot"], default="auto") p.add_argument("--detokenizer", default="auto") p.add_argument("--openwebtext_split", choices=["all", "train_minus_100k", "valid_last_100k"], default="all") p.add_argument("--tokenizer_path", required=True) p.add_argument("--save_dir", required=True) p.add_argument("--max_records", type=int, default=0) p.add_argument( "--elf_conditional_hf", action="store_true", help="Load ELF-style seq2seq Arrow 
datasets with condition_input_ids/input_ids columns.", ) p.add_argument("--eval_data_path", default="") p.add_argument("--dataset_cache_dir", default="") p.add_argument("--max_input_len", type=int, default=64) p.add_argument("--conditional_pad_token", choices=["pad", "eos"], default="eos") p.add_argument("--label_drop_prob", type=float, default=0.0) p.add_argument("--streaming", action="store_true") p.add_argument( "--record_pad_truncate", action="store_true", help=( "ELF-style OWT input pipeline: one record -> one example, no global " "stream packing, truncate to max_len and fixed-pad in the collator." ), ) p.add_argument( "--record_add_eos", action="store_true", help="Append tokenizer EOS/SEP to each record in --record_pad_truncate mode.", ) p.add_argument( "--record_add_special_tokens", action="store_true", help="Let the tokenizer add model special tokens in --record_pad_truncate mode.", ) p.add_argument( "--record_pad_token", choices=["pad", "eos"], default="pad", help="Pad id for --record_pad_truncate. ELF-style uses tokenizer PAD when available.", ) p.add_argument( "--record_shuffle_buffer", type=int, default=10000, help="Bounded online shuffle buffer for --record_pad_truncate; 0 disables it.", ) p.add_argument("--wrap", action="store_true", help="Use wrapped packing: [BOS] + payload + [EOS].") p.add_argument( "--wrap_mode", choices=["stream", "record"], default="stream", help="stream keeps the original fixed-width token stream; record packs whole records and only truncates overlength records.", ) p.add_argument("--wrap_record_buffer_size", type=int, default=200) p.add_argument( "--owt_cached_chunks", action="store_true", help="OWT-only FLM/Duo-style cached chunk pool with epoch-level sampler shuffle.", ) p.add_argument( "--owt_chunk_cache_dir", default="", help="Cache directory for --owt_cached_chunks. 
Stores meta.json and chunks.i32.bin.", ) p.add_argument("--owt_chunk_cache_rebuild", action="store_true") p.add_argument("--owt_chunk_cache_write_batch", type=int, default=4096) p.add_argument( "--owt_exact_repeat_per_chunk", type=int, default=0, help=( "Pilot-only cached OWT sampler: repeat each cached chunk exactly this " "many times, shuffle once deterministically, then shard across ranks." ), ) p.add_argument( "--online_chunk_shuffle", action="store_true", help="Use an online bounded shuffle buffer over wrapped stream chunks.", ) p.add_argument("--online_chunk_shuffle_buffer", type=int, default=10000) p.add_argument("--max_len", type=int, default=128) p.add_argument("--stride", type=int, default=0) p.add_argument("--batch_size", type=int, default=8) p.add_argument("--num_workers", type=int, default=0) p.add_argument("--dataloader_prefetch_factor", type=int, default=2) p.add_argument("--blocking_data_transfer", action="store_true") p.add_argument("--global_batch_size", type=int, default=0, help="If >0, overrides grad_accum with ceil(global_batch_size / batch_size).") p.add_argument("--grad_accum", type=int, default=1) p.add_argument("--total_steps", type=int, default=1000) p.add_argument("--lr", type=float, default=6e-4) p.add_argument("--weight_decay", type=float, default=0.1) p.add_argument( "--output_weight_decay", type=float, default=-1.0, help="If >=0, use this AdamW weight decay only for the output projection weight.", ) p.add_argument("--adam_beta1", type=float, default=0.9) p.add_argument("--adam_beta2", type=float, default=0.95) p.add_argument("--adam_eps", type=float, default=1e-8) p.add_argument("--optimizer", choices=["adamw", "muon"], default="adamw") p.add_argument("--muon_momentum", type=float, default=0.95) p.add_argument("--muon_ns_steps", type=int, default=5) p.add_argument("--muon_update_scale", type=float, default=1.0) p.add_argument("--ema_decay", type=float, default=0.0) p.add_argument("--ema_start_step", type=int, default=0) 
p.add_argument("--warmup_steps", type=int, default=2000) p.add_argument("--lr_schedule", choices=["constant_warmup", "cosine"], default="cosine") p.add_argument("--min_lr", type=float, default=6e-5) p.add_argument( "--adamw_param_groups", choices=["nanogpt", "all_decay"], default="nanogpt", help="nanoGPT decays only 2D matmul/embedding tensors; all_decay applies weight decay to every parameter.", ) p.add_argument("--grad_clip", type=float, default=1.0) p.add_argument("--seed", type=int, default=42) p.add_argument("--d_model", type=int, default=384) p.add_argument("--cond_dim", type=int, default=128) p.add_argument("--n_layers", type=int, default=6) p.add_argument("--n_heads", type=int, default=6) p.add_argument("--dim_ff", type=int, default=1536) p.add_argument("--dropout", type=float, default=0.0) p.add_argument( "--output_bias", action=argparse.BooleanOptionalAction, default=False, help="Whether to include a vocabulary bias in the final output projection. Default false to avoid global-frequency basin amplification.", ) p.add_argument( "--norm_type", choices=["rmsnorm", "layernorm"], default="rmsnorm", help="Normalization used by DDiT blocks/final head. Default rmsnorm follows modern ELF/T5-style practice.", ) p.add_argument( "--vocab_size_override", type=int, default=0, help="Debug-only: train with a compact/remapped vocabulary of this size instead of tokenizer.vocab_size.", ) p.add_argument("--model_type", choices=["transformer", "ddit"], default="transformer") p.add_argument("--state_format", choices=["logprob", "prob"], default="logprob") p.add_argument("--bridge", choices=["prob", "dirichlet", "gaussian"], default="prob") p.add_argument("--target_loss", choices=["hard_ce", "soft_ce"], default="soft_ce") p.add_argument("--meanflow_weight", type=float, default=0.0) p.add_argument( "--loss_t_weight_mode", choices=["none", "linear_t", "quadratic_t", "one_minus_t_floor", "drop_low_t"], default="none", help="Per-sample t loss weighting. 
one_minus_t_floor downweights easy high-t targets.", ) p.add_argument("--loss_t_min_weight", type=float, default=0.0) p.add_argument("--loss_t_drop_below", type=float, default=0.2) p.add_argument( "--prior_center_loss_beta", type=float, default=0.0, help="If >0, train CE on logits - beta * logits(prior_state, t), reducing unconditional prior attraction.", ) p.add_argument("--prior_center_state", choices=["uniform"], default="uniform") p.add_argument( "--rollout_train_prob", type=float, default=0.0, help=( "Probability of replacing the supervised input state with a detached " "self-rollout state before computing the reconstruction loss." ), ) p.add_argument( "--rollout_train_steps", type=int, default=1, help="Number of no-grad self-rollout updates before the supervised loss.", ) p.add_argument( "--rollout_train_infer_steps", type=int, default=64, help="Decode horizon used to set the one-step flowmap gamma during rollout training.", ) p.add_argument( "--rollout_train_temp", type=float, default=1.45, help="Endpoint temperature used in the no-grad rollout-training update.", ) p.add_argument( "--rollout_train_max_gamma", type=float, default=1.0, help="Clamp for the no-grad rollout-training flowmap gamma. Set <=0 to disable the clamp.", ) p.add_argument( "--rollout_train_corrupt_only", action=argparse.BooleanOptionalAction, default=True, help="Only replace corrupted positions with the self-rollout state, preserving clean anchors.", ) p.add_argument( "--rollout_train_samplewise", action=argparse.BooleanOptionalAction, default=False, help=( "Apply rollout stochastically per sample instead of per micro-batch. " "This always computes the rollout forward when prob>0, which keeps " "Tensor Core work steady while preserving the requested rollout mix." 
), ) p.add_argument( "--rollout_train_compute_always", action=argparse.BooleanOptionalAction, default=False, help=( "For batchwise rollout, always compute the detached rollout forward " "even when the batch coin decides not to apply it. This preserves " "the old rollout objective while increasing steady Tensor Core work." ), ) p.add_argument( "--rollout_train_debug_return_base", action=argparse.BooleanOptionalAction, default=False, help=( "Debug only: still run the detached rollout forward, but return the " "original bridge state to isolate rollout-forward side effects from " "the rollout-state training distribution." ), ) p.add_argument("--target_prob", type=float, default=0.99) p.add_argument("--min_t", type=float, default=0.0) p.add_argument("--max_t", type=float, default=1.0) p.add_argument( "--t_sampling_mode", choices=["uniform", "power_low", "power_high", "logit_uniform"], default="uniform", help=( "How to sample model/flow time t. power_low uses u**gamma to " "oversample low t; power_high uses 1-(1-u)**gamma to oversample high t; " "logit_uniform samples uniformly in logit(t)." 
), ) p.add_argument("--t_sampling_power", type=float, default=1.0) p.add_argument("--t_sampling_eps", type=float, default=1e-4) p.add_argument("--dual_t", action="store_true", help="Use separate model/flow time and corruption/support time.") p.add_argument("--corrupt_t_mode", choices=["same", "independent", "constant"], default="independent") p.add_argument("--corrupt_t_value", type=float, default=0.0) p.add_argument("--corrupt_min_t", type=float, default=None) p.add_argument("--corrupt_max_t", type=float, default=None) p.add_argument( "--prefix_block_prob", type=float, default=0.0, help="Train on grow-context states: expose only prefix plus one active block, corrupting the active block and masking future tokens.", ) p.add_argument("--prefix_block_len", type=int, default=128) p.add_argument("--min_mask_ratio", type=float, default=0.1) p.add_argument("--max_mask_ratio", type=float, default=1.0) p.add_argument( "--mask_ratio_floor_schedule", choices=["none", "one_minus_t"], default="none", help=( "Optional per-sample lower-bound schedule for mask ratio. " "one_minus_t uses max(min_mask_ratio, 1-t), so t=0 is all-mask " "while high-t reverts to the configured min_mask_ratio." 
), ) p.add_argument( "--mask_mixture_original_prob", type=float, default=0.0, help="If any mask-mixture prob is >0, probability of using the original uniform mask-ratio sampler.", ) p.add_argument( "--mask_mixture_lowk_prob", type=float, default=0.0, help="If any mask-mixture prob is >0, probability of keeping only K clean tokens and corrupting the rest.", ) p.add_argument( "--mask_mixture_lowcorrupt_prob", type=float, default=0.0, help="If any mask-mixture prob is >0, probability of corrupting only K valid tokens and keeping the rest anchored.", ) p.add_argument( "--mask_mixture_block_prob", type=float, default=0.0, help="If any mask-mixture prob is >0, probability of corrupting one contiguous valid-token block and keeping the rest anchored.", ) p.add_argument( "--mask_mixture_all_prob", type=float, default=0.0, help="If any mask-mixture prob is >0, probability of corrupting every valid token.", ) p.add_argument( "--mask_mixture_lowk_clean_tokens", default="1,2,4,8,16,32,64", help="Comma-separated clean-token counts sampled uniformly for the low-K mixture branch.", ) p.add_argument( "--mask_mixture_lowcorrupt_tokens", type=str, default="1,2,4,8,16,32,64", help="Comma-separated corrupt-token counts sampled uniformly for the low-corrupt-K mixture branch.", ) p.add_argument( "--mask_mixture_block_tokens", type=str, default="64,128", help="Comma-separated contiguous block lengths sampled uniformly for the block-corrupt mixture branch.", ) p.add_argument( "--clean_state_mode", choices=["onehot", "bridge"], default="onehot", help="State for non-corrupted support tokens: exact one-hot gold, or a same-t bridge sample around gold.", ) p.add_argument("--wrong_token_replace_prob", default="0.0") p.add_argument("--wrong_token_schedule", choices=["constant", "hard", "linear_t", "exp_t", "exp_k1"], default="constant") p.add_argument("--wrong_token_exp_k", type=float, default=1.0) p.add_argument("--dirichlet_concentration_min", type=float, default=1.0) 
p.add_argument("--dirichlet_concentration_max", type=float, default=1024.0) p.add_argument( "--dirichlet_endpoint_mode", choices=["bernoulli_wrong", "dual_t_line", "categorical_dual_t"], default="bernoulli_wrong", ) p.add_argument("--dirichlet_semantic_t_mode", choices=["same", "independent", "constant"], default="same") p.add_argument("--dirichlet_semantic_t_value", type=float, default=0.0) p.add_argument( "--dirichlet_semantic_t_curve", choices=["linear", "logit_power"], default="linear", help=( "Transform semantic t before categorical_dual_t Bernoulli gold sampling. " "logit_power uses p=t^gamma/(t^gamma+(1-t)^gamma)." ), ) p.add_argument( "--dirichlet_semantic_t_power", type=float, default=1.0, help="Gamma for --dirichlet_semantic_t_curve logit_power.", ) p.add_argument( "--endpoint_sequence_random_prob_alpha", type=float, default=0.0, help=( "For categorical_dual_t: with probability alpha*(1-t), force all " "corrupted positions in a sequence to use the random endpoint branch. " "0 disables the sequence-level gate." ), ) p.add_argument( "--categorical_wrong_from_full_vocab", action="store_true", help="For categorical_dual_t only: sample the non-gold endpoint from the full vocab, allowing it to equal gold.", ) p.add_argument( "--categorical_wrong_from_batch_valid_tokens", action="store_true", help="For categorical_dual_t only: sample the non-gold endpoint from valid tokens in the current batch.", ) p.add_argument( "--categorical_wrong_basin_token_ids", default="", help=( "For categorical_dual_t only: comma-separated token ids for a basin-bank wrong endpoint branch. " "Used when any categorical_wrong_*_prob is >0." 
), ) p.add_argument("--categorical_wrong_basin_prob", type=float, default=0.0) p.add_argument("--categorical_wrong_unigram_prob", type=float, default=0.0) p.add_argument("--categorical_wrong_uniform_prob", type=float, default=0.0) p.add_argument( "--categorical_wrong_corpus_unigram_path", default="", help=( "Optional torch/numpy file containing OWT corpus unigram counts or probabilities. " "When set, categorical_wrong_unigram_prob samples from this corpus distribution." ), ) p.add_argument( "--categorical_wrong_corpus_unigram_alpha", type=float, default=1.0, help="Exponent applied to corpus unigram counts before normalization; values <1 flatten high-frequency tokens.", ) p.add_argument( "--categorical_wrong_basin_shared_prob", type=float, default=0.0, help="Per-sequence probability of sharing one basin-bank wrong token across all corrupted positions.", ) p.add_argument( "--categorical_wrong_unigram_shared_prob", type=float, default=0.0, help="Per-sequence probability of sharing one corpus/batch-unigram wrong token across all corrupted positions.", ) p.add_argument( "--simplex_bridge_sampler", choices=["dirichlet", "logistic_normal_linear_mean"], default="dirichlet", help="How to sample simplex bridge states after choosing the endpoint.", ) p.add_argument("--logistic_normal_sigma_min", type=float, default=0.18) p.add_argument("--logistic_normal_sigma_max", type=float, default=2.2) p.add_argument("--logistic_normal_tau_min", type=float, default=0.65) p.add_argument("--logistic_normal_tau_max", type=float, default=1.15) p.add_argument("--dirichlet_scale", type=float, default=128.0) p.add_argument("--dirichlet_floor", type=float, default=1e-3) p.add_argument("--eps", type=float, default=1e-8) p.add_argument("--log_every", type=int, default=10) p.add_argument("--eval_every", type=int, default=100) p.add_argument("--save_every", type=int, default=500) p.add_argument("--latest_every", type=int, default=0, help="If >0, also refresh save_dir/latest.pt at this interval.") 
p.add_argument("--resume_path", default="", help="Optional checkpoint path to resume model/optimizer/scheduler state.") p.add_argument( "--init_model_path", default="", help="Optional checkpoint path used only to initialize model weights; optimizer/scheduler/step start fresh.", ) p.add_argument( "--init_pos_embed_mode", choices=["strict", "repeat", "interpolate"], default="strict", help="When --init_model_path has a different pos_embed length, adapt it to the current max_len.", ) p.add_argument("--infer_steps", type=int, default=64) p.add_argument("--decode_damping", type=float, default=1.0) p.add_argument("--max_gamma", type=float, default=1.0) p.add_argument("--decode_solver", choices=["simple", "flowmap"], default="flowmap") p.add_argument("--noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal") p.add_argument("--bridge_noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal") p.add_argument("--noise_sigma", type=float, default=-1.0) p.add_argument("--demo_mask_ratios", default="0.1,0.5,1.0") p.add_argument("--demo_fill_t", type=float, default=0.0) p.add_argument("--bf16", action="store_true") p.add_argument("--allow_tf32", action=argparse.BooleanOptionalAction, default=True) p.add_argument("--activation_checkpointing", action="store_true") p.add_argument("--activation_checkpoint_interval", type=int, default=1) p.add_argument("--activation_checkpoint_scope", choices=["block", "mlp"], default="block") p.add_argument("--ddp_static_graph", action="store_true") p.add_argument("--ddp_gradient_as_bucket_view", action=argparse.BooleanOptionalAction, default=True) p.add_argument( "--full_train_stats", action="store_true", help="Compute heavy train diagnostics every micro-step. 
Default computes cheap stats every optimizer step and heavy loss diagnostics on log steps.", ) p.add_argument( "--log_train_stats_each_step", action=argparse.BooleanOptionalAction, default=True, help="Accumulate lightweight train diagnostics from every micro-batch, then report weighted averages over each log window.", ) p.add_argument( "--log_t_bin_stats", action=argparse.BooleanOptionalAction, default=True, help="Log corrupt-token accuracy bucketed by model t, making mixed-t train logs easier to read.", ) p.add_argument( "--log_t_bin_edges", default="0,0.2,0.4,0.6,0.8,1.0", help="Comma-separated t-bin edges used when --log_t_bin_stats is enabled.", ) p.add_argument("--torch_compile", action="store_true") p.add_argument("--compile_mode", choices=["default", "reduce-overhead", "max-autotune"], default="max-autotune") return p.parse_args() def _datasetdict_to_single_split(ds): if hasattr(ds, "keys") and not hasattr(ds, "column_names"): for key in ("train", "validation", "test"): if key in ds: return ds[key] return ds[next(iter(ds.keys()))] return ds def _load_arrow_dir_without_hf_metadata(path: str): """Fallback for ELF datasets saved with newer `List` feature metadata. Some nodes have an older `datasets` build that cannot parse feature type `List`, even though pyarrow can read the Arrow files just fine. We strip only the HuggingFace schema metadata and let `datasets` infer equivalent sequence columns from Arrow. 
""" import pyarrow as pa import pyarrow.ipc as ipc from datasets import Dataset root = Path(path) arrow_files = sorted(root.glob("data-*.arrow")) if not arrow_files: raise FileNotFoundError(f"No data-*.arrow files found under {path}") tables = [] for arrow_path in arrow_files: with pa.memory_map(str(arrow_path), "r") as source: table = ipc.open_stream(source).read_all() tables.append(table.cast(table.schema.remove_metadata())) table = tables[0] if len(tables) == 1 else pa.concat_tables(tables, promote_options="default") return Dataset(table) def load_elf_conditional_dataset(path: str, max_records: int = 0): from datasets import load_from_disk try: ds = _datasetdict_to_single_split(load_from_disk(path)) except ValueError as exc: if "Feature type 'List' not found" not in str(exc): raise ds = _load_arrow_dir_without_hf_metadata(path) if max_records and max_records > 0: ds = ds.select(range(min(int(max_records), len(ds)))) required = {"condition_input_ids", "input_ids"} missing = sorted(required.difference(ds.column_names)) if missing: raise ValueError(f"ELF conditional dataset {path} missing columns: {missing}") return ds class ELFConditionalCollator: def __init__(self, pad_id: int, max_len: int, max_input_len: int, *, loss_on_pad: bool) -> None: self.pad_id = int(pad_id) self.max_len = int(max_len) self.max_input_len = int(max_input_len) self.loss_on_pad = bool(loss_on_pad) def __call__(self, examples: list[dict]) -> dict[str, torch.Tensor]: ids_rows = [] attn_rows = [] loss_rows = [] cond_rows = [] for ex in examples: cond = [int(x) for x in ex["condition_input_ids"][: self.max_input_len]] target_budget = max(0, self.max_len - len(cond)) target = [int(x) for x in ex["input_ids"][:target_budget]] seq = (cond + target)[: self.max_len] cond_len = min(len(cond), self.max_len) target_len = max(0, len(seq) - cond_len) pad_len = self.max_len - len(seq) if pad_len > 0: seq = seq + [self.pad_id] * pad_len attn = [True] * (cond_len + target_len) + [False] * pad_len cond_mask 
= [True] * cond_len + [False] * (self.max_len - cond_len) if self.loss_on_pad: loss_mask = [False] * cond_len + [True] * (self.max_len - cond_len) else: loss_mask = [False] * cond_len + [True] * target_len + [False] * pad_len ids_rows.append(seq) attn_rows.append(attn) loss_rows.append(loss_mask) cond_rows.append(cond_mask) return { "ids": torch.tensor(ids_rows, dtype=torch.long), "attn_mask": torch.tensor(attn_rows, dtype=torch.bool), "loss_mask": torch.tensor(loss_rows, dtype=torch.bool), "cond_seq_mask": torch.tensor(cond_rows, dtype=torch.bool), } class ExactRepeatDistributedSampler(Sampler[int]): """Deterministic finite schedule with exact per-item repeat counts.""" def __init__( self, dataset_size: int, repeats_per_item: int, *, seed: int, rank: int = 0, world_size: int = 1, ) -> None: self.dataset_size = int(dataset_size) self.repeats_per_item = int(repeats_per_item) self.seed = int(seed) self.rank = int(rank) self.world_size = int(world_size) if self.dataset_size <= 0: raise ValueError("dataset_size must be positive") if self.repeats_per_item <= 0: raise ValueError("repeats_per_item must be positive") total = self.dataset_size * self.repeats_per_item if total % self.world_size != 0: raise ValueError( f"exact repeat schedule total={total} is not divisible by world_size={self.world_size}" ) self.num_samples = total // self.world_size def __iter__(self): import torch indices = torch.arange(self.dataset_size, dtype=torch.long).repeat_interleave( self.repeats_per_item ) g = torch.Generator() g.manual_seed(self.seed) indices = indices[torch.randperm(indices.numel(), generator=g)] return iter(indices[self.rank :: self.world_size].tolist()) def __len__(self) -> int: return self.num_samples def seed_all(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def sample_time( batch: int, min_t: float, max_t: float, device: torch.device, mode: str = "uniform", power: float = 1.0, eps: float = 1e-4, ) -> torch.Tensor: u = 
torch.rand(batch, device=device) if mode == "uniform": z = u elif mode == "power_low": z = u.pow(float(power)) elif mode == "power_high": z = 1.0 - (1.0 - u).pow(float(power)) elif mode == "logit_uniform": lo = torch.logit(torch.tensor(float(eps), device=device)) hi = torch.logit(torch.tensor(1.0 - float(eps), device=device)) z = torch.sigmoid(lo + (hi - lo) * u) else: raise ValueError(f"Unknown t_sampling_mode: {mode}") return min_t + (max_t - min_t) * z def sample_loss_t_weight(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor: if args.loss_t_weight_mode == "none": return torch.ones_like(model_t) t = model_t.clamp(0.0, 1.0) if args.loss_t_weight_mode == "linear_t": weight = t elif args.loss_t_weight_mode == "quadratic_t": weight = t.square() elif args.loss_t_weight_mode == "one_minus_t_floor": weight = 1.0 - t elif args.loss_t_weight_mode == "drop_low_t": weight = torch.where(t < float(args.loss_t_drop_below), torch.zeros_like(t), torch.ones_like(t)) else: raise ValueError(f"Unknown loss_t_weight_mode: {args.loss_t_weight_mode}") if args.loss_t_min_weight > 0: weight = weight.clamp_min(float(args.loss_t_min_weight)) return weight def parse_t_bin_edges(raw: str) -> tuple[float, ...]: vals = tuple(float(x) for x in raw.split(",") if x.strip()) if len(vals) < 2: raise ValueError("--log_t_bin_edges must contain at least two comma-separated values") if any(vals[i] >= vals[i + 1] for i in range(len(vals) - 1)): raise ValueError(f"--log_t_bin_edges must be strictly increasing, got {raw!r}") if vals[0] > 0.0 or vals[-1] < 1.0: raise ValueError(f"--log_t_bin_edges must cover [0, 1], got {raw!r}") return vals def average_stats_window(running: list[dict[str, float]]) -> dict[str, float]: keys: list[str] = [] seen: set[str] = set() for stats in running: for key in stats: if key.endswith("__weight"): continue if key not in seen: seen.add(key) keys.append(key) avg: dict[str, float] = {} for key in keys: weighted_sum = 0.0 weight_sum = 0.0 vals: list[float] = [] 
weight_key = f"{key}__weight" for stats in running: if key not in stats: continue val = float(stats[key]) if not math.isfinite(val): continue if weight_key in stats: weight = float(stats[weight_key]) if math.isfinite(weight) and weight > 0.0: weighted_sum += val * weight weight_sum += weight else: vals.append(val) if weight_sum > 0.0: avg[key] = weighted_sum / weight_sum elif vals: avg[key] = sum(vals) / len(vals) return avg def weighted_masked_ce(logits: torch.Tensor, ids: torch.Tensor, mask: torch.Tensor, sample_weight: torch.Tensor) -> torch.Tensor: per = torch.nn.functional.cross_entropy(logits.flatten(0, 1), ids.flatten(), reduction="none").view_as(ids) weight = mask.float() * sample_weight.float().view(-1, 1) return (per * weight).sum() / weight.sum().clamp_min(1.0) def weighted_masked_soft_ce( logits: torch.Tensor, target_probs: torch.Tensor, mask: torch.Tensor, sample_weight: torch.Tensor, ) -> torch.Tensor: log_probs = torch.nn.functional.log_softmax(logits, dim=-1) per = -(target_probs.to(dtype=log_probs.dtype) * log_probs).sum(dim=-1).float() weight = mask.float() * sample_weight.float().view(-1, 1) return (per * weight).sum() / weight.sum().clamp_min(1.0) def make_prior_center_state( args: argparse.Namespace, batch: int, length: int, vocab_size: int, device: torch.device, ) -> torch.Tensor: if args.prior_center_state != "uniform": raise ValueError(f"Unknown prior_center_state: {args.prior_center_state}") probs = torch.full((batch, length, vocab_size), 1.0 / vocab_size, dtype=torch.float32, device=device) if args.state_format == "prob": return probs return torch.log(probs.clamp_min(args.eps)) def state_to_probs(args: argparse.Namespace, state: torch.Tensor) -> torch.Tensor: if args.state_format == "prob": probs = state.float() else: probs = state.float().exp() probs = probs.clamp_min(args.eps) return probs / probs.sum(dim=-1, keepdim=True).clamp_min(args.eps) def probs_to_state(args: argparse.Namespace, probs: torch.Tensor) -> torch.Tensor: probs = 
probs.float().clamp_min(args.eps) probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(args.eps) if args.state_format == "prob": return probs return torch.log(probs.clamp_min(args.eps)) def zero_param_anchor(model: torch.nn.Module, like: torch.Tensor) -> torch.Tensor: anchor = like.new_zeros(()) for p in model.parameters(): if p.requires_grad: anchor = anchor + p.reshape(-1)[0].float() * 0.0 return anchor @torch.no_grad() def maybe_make_rollout_train_state( args: argparse.Namespace, model: torch.nn.Module, base_state: torch.Tensor, model_t: torch.Tensor, attn_mask: torch.Tensor, corrupt_mask: torch.Tensor, ) -> tuple[torch.Tensor, float, torch.Tensor | None]: prob = float(args.rollout_train_prob) steps = int(args.rollout_train_steps) if prob <= 0.0 or steps <= 0: return base_state, 0.0, None samplewise = bool(getattr(args, "rollout_train_samplewise", False)) if samplewise: apply_sample = torch.rand(base_state.size(0), device=model_t.device) < prob applied = float(apply_sample.float().mean().detach().cpu()) else: apply_batch = prob >= 1.0 or float(torch.rand((), device=model_t.device).item()) < prob if not apply_batch and not bool(getattr(args, "rollout_train_compute_always", False)): return base_state, 0.0, None apply_sample = torch.full((base_state.size(0),), bool(apply_batch), device=model_t.device) applied = 1.0 if apply_batch else 0.0 probs = state_to_probs(args, base_state) original_probs = probs rollout_t = model_t.float().clamp(0.0, 1.0) dt = 1.0 / max(1, int(args.rollout_train_infer_steps)) with torch.no_grad(): for _ in range(steps): logits = model(probs_to_state(args, probs), rollout_t, attn_mask) endpoint = torch.nn.functional.softmax(logits / float(args.rollout_train_temp), dim=-1) del logits gamma = dt / (1.0 - rollout_t).clamp_min(args.eps) if float(args.rollout_train_max_gamma) > 0: gamma = gamma.clamp_max(float(args.rollout_train_max_gamma)) probs = probs + gamma.view(-1, 1, 1) * (endpoint - probs) probs = probs.clamp_min(args.eps) probs = 
@torch.no_grad()
def zeropower_via_newtonschulz5(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a matrix-shaped update (Muon Newton-Schulz).

    ``g`` is viewed as a 2-D matrix of shape ``(g.shape[0], -1)`` in float32,
    scaled by its norm, and pushed through a fixed quintic polynomial
    iteration. Tall matrices are processed transposed so the Gram product is
    formed on the smaller side.

    Args:
        g: Tensor with at least two dimensions.
        steps: Number of iterations; values below 1 still run one iteration.

    Returns:
        Tensor with the shape and dtype of ``g``.
    """
    coeff_a, coeff_b, coeff_c = 3.4445, -4.7750, 2.0315
    mat = g.float().reshape(g.shape[0], -1)
    flipped = mat.size(0) > mat.size(1)
    if flipped:
        mat = mat.t()
    # Small constant keeps the normalization finite for an all-zero input.
    mat = mat / (mat.norm() + 1e-7)
    num_iters = max(1, int(steps))
    for _ in range(num_iters):
        gram = mat @ mat.t()
        mat = coeff_a * mat + (coeff_b * gram + coeff_c * (gram @ gram)) @ mat
    result = mat.t() if flipped else mat
    return result.reshape(g.shape).to(dtype=g.dtype)


class Muon(torch.optim.Optimizer):
    """Minimal Muon optimizer with an Adam fallback for parameters below 2-D.

    Trainable parameters are split into two groups at construction time:
    tensors with ``dim() >= 2`` get a momentum buffer whose
    Newton-Schulz-orthogonalized value is used as the update, while the
    remaining tensors get a bias-corrected Adam update. Both paths apply
    weight decay multiplicatively (``p *= 1 - lr * weight_decay``) before the
    update. A parameter group without a ``use_muon`` key is treated as an Adam
    group.
    """

    def __init__(
        self,
        params,
        *,
        lr: float,
        momentum: float = 0.95,
        ns_steps: int = 5,
        adam_betas: tuple[float, float] = (0.9, 0.95),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        update_scale: float = 1.0,
    ) -> None:
        trainable = [p for p in params if p.requires_grad]
        defaults = {
            "lr": lr,
            "momentum": momentum,
            "ns_steps": ns_steps,
            "adam_betas": adam_betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "update_scale": update_scale,
        }
        partitions = (
            (True, [p for p in trainable if p.dim() >= 2]),
            (False, [p for p in trainable if p.dim() < 2]),
        )
        groups = []
        for use_muon, members in partitions:
            if members:
                groups.append({"params": members, "use_muon": use_muon, **defaults})
        super().__init__(groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step; return the closure's loss if one is given."""
        loss = None
        if closure is not None:
            # The closure re-evaluates the model, so it needs autograd enabled.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            if group.get("use_muon", False):
                self._muon_group_step(group)
            else:
                self._adam_group_step(group)
        return loss

    @torch.no_grad()
    def _muon_group_step(self, group) -> None:
        """Apply the momentum + Newton-Schulz update to one parameter group."""
        lr = group["lr"]
        wd = group["weight_decay"]
        momentum = group["momentum"]
        ns_steps = group["ns_steps"]
        update_scale = group["update_scale"]
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Muon does not support sparse gradients")
            state = self.state[p]
            if len(state) == 0:
                state["momentum_buffer"] = torch.zeros_like(p)
            buf = state["momentum_buffer"]
            buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
            update = zeropower_via_newtonschulz5(buf, ns_steps)
            if wd:
                p.mul_(1.0 - lr * wd)
            p.add_(update, alpha=-lr * update_scale)

    @torch.no_grad()
    def _adam_group_step(self, group) -> None:
        """Apply the bias-corrected Adam update to one parameter group."""
        lr = group["lr"]
        wd = group["weight_decay"]
        beta1, beta2 = group["adam_betas"]
        eps = group["eps"]
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Adam fallback does not support sparse gradients")
            state = self.state[p]
            if len(state) == 0:
                state["step"] = 0
                state["exp_avg"] = torch.zeros_like(p)
                state["exp_avg_sq"] = torch.zeros_like(p)
            first_moment = state["exp_avg"]
            second_moment = state["exp_avg_sq"]
            state["step"] += 1
            t = state["step"]
            if wd:
                p.mul_(1.0 - lr * wd)
            first_moment.mul_(beta1).add_(grad, alpha=1.0 - beta1)
            second_moment.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
            bias_correction1 = 1.0 - beta1**t
            bias_correction2 = 1.0 - beta2**t
            step_size = lr * (bias_correction2**0.5) / bias_correction1
            denom = second_moment.sqrt().add_(eps)
            p.addcdiv_(first_moment, denom, value=-step_size)
eps=args.adam_eps, weight_decay=args.weight_decay, update_scale=args.muon_update_scale, ) fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters extra_args = {"fused": True} if fused_available and device.type == "cuda" else {} output_decay = float(getattr(args, "output_weight_decay", -1.0)) output_weight_names = ("output_layer.linear.weight", "out_proj.weight") output_params = [ p for name, p in named_params if output_decay >= 0.0 and any(name.endswith(suffix) for suffix in output_weight_names) ] output_param_ids = {id(p) for p in output_params} if args.adamw_param_groups == "all_decay": if output_params: other_params = [p for p in params if id(p) not in output_param_ids] optim_groups = [ {"params": other_params, "weight_decay": args.weight_decay}, {"params": output_params, "weight_decay": output_decay}, ] return torch.optim.AdamW( optim_groups, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, **extra_args, ) return torch.optim.AdamW( params, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, weight_decay=args.weight_decay, **extra_args, ) # nanoGPT convention: decay matrix/embedding weights, not biases/norm/1D params. 
@torch.no_grad()
def init_ema_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Clone every floating-point entry of ``model.state_dict()`` as the EMA start.

    Non-floating entries (e.g. integer buffers) are left out, so they are
    never averaged.
    """
    return {
        name: tensor.detach().clone()
        for name, tensor in model.state_dict().items()
        if torch.is_floating_point(tensor)
    }


@torch.no_grad()
def update_ema_state(ema_state: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
    """Update ``ema_state`` in place: ``ema = decay * ema + (1 - decay) * current``.

    Only keys already present in ``ema_state`` are touched; other state-dict
    entries are ignored.
    """
    for name, tensor in model.state_dict().items():
        if name in ema_state:
            ema_state[name].mul_(decay).add_(tensor.detach(), alpha=1.0 - decay)


def resolve_bridge_force_t(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor | None:
    """Choose the time tensor handed to the bridge as ``force_t``.

    Returns:
        ``model_t`` itself when ``args.dual_t`` is off or
        ``args.corrupt_t_mode == "same"``; a tensor shaped like ``model_t``
        filled with ``args.corrupt_t_value`` when the mode is ``"constant"``;
        otherwise ``None``.
    """
    if not args.dual_t or args.corrupt_t_mode == "same":
        return model_t
    if args.corrupt_t_mode == "constant":
        return torch.full_like(model_t, float(args.corrupt_t_value))
    return None


def load_corpus_unigram_probs(
    path: str,
    vocab_size: int,
    alpha: float,
    device: torch.device,
) -> torch.Tensor | None:
    """Load a unigram count/probability vector and normalize it over the vocabulary.

    The file is read with NumPy when its suffix is ``.npy`` and with
    ``torch.load`` otherwise. A dict payload is searched, in order, for the
    keys ``probs``, ``prob``, ``counts``, ``count``, ``unigram_counts``. The
    flattened values are zero-padded or truncated to ``vocab_size``, negative
    entries are clamped to zero, the vector is raised to ``alpha`` (skipped
    when ``alpha == 1``), and the result is divided by its sum.

    Args:
        path: File to load; an empty value disables the feature.
        vocab_size: Length of the returned vector.
        alpha: Exponent applied to the clamped values before normalization.
        device: Device of the returned tensor.

    Returns:
        A float32 tensor of length ``vocab_size`` that sums to one, or
        ``None`` when ``path`` is empty.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: A dict payload has none of the recognized keys, or the
            processed vector has no positive finite mass.
    """
    if not path:
        return None
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"categorical wrong corpus unigram file not found: {p}")
    if p.suffix == ".npy":
        import numpy as np

        # Cast in NumPy first: torch.from_numpy only accepts a subset of NumPy
        # dtypes/byte orders (unsigned count dtypes are rejected by older
        # PyTorch releases), and the values end up as float32 anyway.
        data = torch.from_numpy(np.load(p).astype(np.float32))
    else:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted files (or pass weights_only=True where the installed
        # PyTorch version supports it).
        loaded = torch.load(p, map_location="cpu")
        if isinstance(loaded, dict):
            known_keys = ("probs", "prob", "counts", "count", "unigram_counts")
            found = next((key for key in known_keys if key in loaded), None)
            if found is None:
                # Previously the dict itself reached torch.as_tensor and
                # failed with an unrelated-looking dtype error.
                raise ValueError(
                    f"corpus unigram file {p} is a dict without any of the keys {known_keys}"
                )
            loaded = loaded[found]
        data = torch.as_tensor(loaded)
    probs = data.flatten().to(dtype=torch.float32, device="cpu")
    if probs.numel() < vocab_size:
        # Vocabulary entries beyond the file's length get zero mass.
        probs = F.pad(probs, (0, vocab_size - probs.numel()))
    elif probs.numel() > vocab_size:
        probs = probs[:vocab_size]
    probs = probs.clamp_min(0.0)
    if float(alpha) != 1.0:
        probs = probs.pow(float(alpha))
    total = probs.sum()
    if not bool(torch.isfinite(total).item()) or float(total.item()) <= 0.0:
        raise ValueError(f"corpus unigram file has no positive finite mass: {p}")
    return (probs / total).to(device=device)
make_gaussian_bridge_batch is None: raise RuntimeError("bridge=gaussian requested, but make_gaussian_bridge_batch is unavailable") return make_gaussian_bridge_batch( ids=ids, attn_mask=attn_mask, vocab_size=vocab_size, target_prob=args.target_prob, min_t=corrupt_min_t, max_t=corrupt_max_t, eps=args.eps, force_t=force_t, ) return make_dirichlet_bridge_batch( ids=ids, attn_mask=attn_mask, vocab_size=vocab_size, target_prob=args.target_prob, min_t=corrupt_min_t, max_t=corrupt_max_t, min_mask_ratio=args.min_mask_ratio, max_mask_ratio=args.max_mask_ratio, wrong_token_replace_prob=args.wrong_token_replace_prob, wrong_token_schedule=args.wrong_token_schedule, wrong_token_exp_k=args.wrong_token_exp_k, dirichlet_concentration_min=args.dirichlet_concentration_min, dirichlet_concentration_max=args.dirichlet_concentration_max, eps=args.eps, state_format=args.state_format, dirichlet_endpoint_mode=args.dirichlet_endpoint_mode, dirichlet_semantic_t_mode=args.dirichlet_semantic_t_mode, dirichlet_semantic_t_value=args.dirichlet_semantic_t_value, dirichlet_semantic_t_curve=args.dirichlet_semantic_t_curve, dirichlet_semantic_t_power=args.dirichlet_semantic_t_power, endpoint_sequence_random_prob_alpha=args.endpoint_sequence_random_prob_alpha, categorical_wrong_from_full_vocab=args.categorical_wrong_from_full_vocab, categorical_wrong_from_batch_valid_tokens=args.categorical_wrong_from_batch_valid_tokens, categorical_wrong_basin_token_ids=args.categorical_wrong_basin_token_ids, categorical_wrong_basin_prob=args.categorical_wrong_basin_prob, categorical_wrong_unigram_prob=args.categorical_wrong_unigram_prob, categorical_wrong_uniform_prob=args.categorical_wrong_uniform_prob, categorical_wrong_basin_shared_prob=args.categorical_wrong_basin_shared_prob, categorical_wrong_unigram_shared_prob=args.categorical_wrong_unigram_shared_prob, categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs, simplex_bridge_sampler=args.simplex_bridge_sampler, 
def maybe_make_prefix_block_masks(
    args: argparse.Namespace,
    attn_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Randomly turn some rows into "visible prefix + one corrupted block" examples.

    Each row is selected independently with probability
    ``args.prefix_block_prob``. For a selected row, its valid positions (the
    nonzero entries of ``attn_mask``) are cut into consecutive blocks of
    ``args.prefix_block_len`` positions and one block is drawn uniformly; that
    block is marked in the force-corrupt mask and every valid position after
    it is removed from the returned attention mask. Rows that are not
    selected, or that have no valid position, keep their attention mask and
    get an all-False force-corrupt row.

    Returns:
        ``(attn_mask, None)`` unchanged when the probability is not positive
        or no row was selected; otherwise ``(effective_attn, force_corrupt)``
        where ``force_corrupt`` is a bool tensor shaped like ``attn_mask``.
    """
    prob = float(args.prefix_block_prob)
    if prob <= 0.0:
        return attn_mask, None

    device = attn_mask.device
    batch, _length = attn_mask.shape  # unpacking also requires a 2-D mask
    block_len = max(1, int(args.prefix_block_len))

    chosen_rows = (torch.rand(batch, device=device) < prob).nonzero(as_tuple=False).flatten().tolist()
    if not chosen_rows:
        return attn_mask, None

    effective_attn = attn_mask.clone()
    force_corrupt = torch.zeros_like(attn_mask, dtype=torch.bool)
    for row in chosen_rows:
        valid_pos = attn_mask[row].nonzero(as_tuple=False).flatten()
        n_valid = int(valid_pos.numel())
        if n_valid == 0:
            continue
        n_blocks = max(1, math.ceil(n_valid / block_len))
        block_idx = int(torch.randint(0, n_blocks, (1,), device=device).item())
        start = block_idx * block_len
        end = min(start + block_len, n_valid)
        force_corrupt[row, valid_pos[start:end]] = True
        tail = valid_pos[end:]
        if tail.numel() > 0:
            # Hide everything after the chosen block from attention.
            effective_attn[row, tail] = False
    return effective_attn, force_corrupt


def dataloader_perf_kwargs(args: argparse.Namespace, device: torch.device) -> dict:
    """Build the worker-related keyword arguments for ``DataLoader``.

    ``pin_memory`` is enabled only on CUDA with at least one worker,
    ``persistent_workers`` whenever workers are used, and ``prefetch_factor``
    is added only when workers are used and the configured factor is positive.
    """
    has_workers = args.num_workers > 0
    kwargs = {
        "pin_memory": device.type == "cuda" and has_workers,
        "persistent_workers": has_workers,
    }
    if has_workers and args.dataloader_prefetch_factor > 0:
        kwargs["prefetch_factor"] = int(args.dataloader_prefetch_factor)
    return kwargs


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Strip the DDP wrapper, then return the ``_orig_mod`` attribute if present."""
    inner = model.module if isinstance(model, DistributedDataParallel) else model
    return getattr(inner, "_orig_mod", inner)


def rank_zero(rank: int) -> bool:
    """Return True when ``rank`` equals 0."""
    return rank == 0


def save_checkpoint(
    path: Path,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    args: argparse.Namespace,
    tokenizer: BpeTextTokenizer,
    step: int,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
    ema_state: dict[str, torch.Tensor] | None = None,
) -> None:
    """Write a training checkpoint to ``path`` via a temporary sibling file.

    The payload holds the model and optimizer state dicts, ``vars(args)``,
    the step, the vocabulary size (``args.effective_vocab_size`` when set and
    non-zero, else ``tokenizer.vocab_size``) and ``tokenizer.eos_id``. When
    ``ema_state`` is given, ``ema_model`` mirrors the model's state-dict
    keys: EMA tensors are moved to CPU in the model tensor's dtype, and keys
    without an EMA entry fall back to the live value. The scheduler state is
    stored only when a scheduler is passed.

    The file is first saved as ``<path>.tmp`` and then renamed over ``path``,
    so ``path`` is only replaced once the save has finished.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    vocab_size = int(getattr(args, "effective_vocab_size", 0) or tokenizer.vocab_size)
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "args": vars(args),
        "step": step,
        "vocab_size": vocab_size,
        "eos_id": tokenizer.eos_id,
    }

    if ema_state is not None:
        ema_payload = {}
        for key, live in model.state_dict().items():
            if key in ema_state:
                ema_payload[key] = ema_state[key].to(dtype=live.dtype, device="cpu")
            else:
                ema_payload[key] = live.detach().cpu()
        payload["ema_model"] = ema_payload

    if scheduler is not None:
        payload["scheduler"] = scheduler.state_dict()

    tmp_path = path.with_suffix(path.suffix + ".tmp")
    torch.save(payload, tmp_path)
    tmp_path.replace(path)
soft_residual_decode( model, init, attn_mask, args.infer_steps, args.decode_damping, args.max_gamma, args.eps, solver=args.decode_solver, ) print(f"[demo mask_ratio={ratio:.2f}] init : {tokenizer.decode(init.argmax(-1)[0].tolist())[:500]}", flush=True) print(f"[demo mask_ratio={ratio:.2f}] final: {tokenizer.decode(final.argmax(-1)[0].tolist())[:500]}", flush=True) print(f"[demo mask_ratio={ratio:.2f}] corrupt_acc_init={float(((init.argmax(-1)==ids)&mask).float().sum()/mask.float().sum().clamp_min(1)):.4f} corrupt_acc_final={float(((final.argmax(-1)==ids)&mask).float().sum()/mask.float().sum().clamp_min(1)):.4f}", flush=True) model.train() def main() -> None: args = parse_args() t_bin_edges = parse_t_bin_edges(args.log_t_bin_edges) world_size = int(os.environ.get("WORLD_SIZE", "1")) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) ddp = world_size > 1 if ddp: if torch.cuda.is_available(): torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl") if args.global_batch_size > 0: args.grad_accum = max(1, math.ceil(args.global_batch_size / max(1, args.batch_size * world_size))) args.resolved_detokenizer = infer_detokenizer_name(raw_path=args.data_path, explicit=args.detokenizer) seed_all(args.seed + rank) device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = bool(args.allow_tf32) torch.backends.cudnn.allow_tf32 = bool(args.allow_tf32) def ddp_barrier() -> None: if not ddp: return if device.type == "cuda": dist.barrier(device_ids=[local_rank]) else: dist.barrier() tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) model_vocab_size = int(args.vocab_size_override) if int(args.vocab_size_override) > 0 else int(tokenizer.vocab_size) args.effective_vocab_size = model_vocab_size categorical_wrong_corpus_unigram_probs = load_corpus_unigram_probs( args.categorical_wrong_corpus_unigram_path, model_vocab_size, 
args.categorical_wrong_corpus_unigram_alpha, device, ) if rank_zero(rank) and categorical_wrong_corpus_unigram_probs is not None: top_probs, top_ids = torch.topk(categorical_wrong_corpus_unigram_probs.detach().cpu(), k=min(10, model_vocab_size)) top_summary = ",".join(f"{int(i)}:{float(p):.3e}" for i, p in zip(top_ids, top_probs)) print( "loaded_categorical_wrong_corpus_unigram=" f"{args.categorical_wrong_corpus_unigram_path} " f"alpha={args.categorical_wrong_corpus_unigram_alpha} top={top_summary}", flush=True, ) shuffle_epoch_sampler = None if args.elf_conditional_hf: if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle or args.streaming or args.record_pad_truncate: raise ValueError("--elf_conditional_hf is mutually exclusive with text streaming/packing modes") if args.conditional_pad_token == "pad": conditional_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else tokenizer.eos_id loss_on_pad = False else: conditional_pad_id = tokenizer.eos_id loss_on_pad = True dataset = load_elf_conditional_dataset(args.data_path, max_records=args.max_records) sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) loader = DataLoader( dataset, batch_size=args.batch_size, sampler=sampler, collate_fn=ELFConditionalCollator( conditional_pad_id, max_len=args.max_len, max_input_len=args.max_input_len, loss_on_pad=loss_on_pad, ), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=True, ) shuffle_epoch_sampler = sampler dataset_size = ( f"elf_conditional_hf:{len(dataset)}" f":max_input_len={args.max_input_len}:pad={conditional_pad_id}:loss_on_pad={int(loss_on_pad)}" ) elif args.record_pad_truncate: if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle: raise ValueError("--record_pad_truncate is mutually exclusive with --wrap/--owt_cached_chunks/--online_chunk_shuffle") if args.record_pad_token == "pad": record_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None 
else tokenizer.eos_id else: record_pad_id = tokenizer.eos_id dataset = RecordPadTruncateTextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records_per_epoch=args.max_records, detokenizer=args.resolved_detokenizer, add_eos=args.record_add_eos, add_special_tokens=args.record_add_special_tokens, shuffle_buffer_size=args.record_shuffle_buffer, seed=args.seed, epoch=0, ) shuffle_epoch_sampler = dataset loader = DataLoader( dataset, batch_size=args.batch_size, collate_fn=FixedPadCollator(record_pad_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=True, ) dataset_size = ( "record_pad_truncate" f":pad={record_pad_id}:add_eos={int(args.record_add_eos)}" f":add_special={int(args.record_add_special_tokens)}" f":shuffle_buffer={args.record_shuffle_buffer}" ) elif args.owt_cached_chunks: if not args.wrap or args.wrap_mode != "stream": raise ValueError("--owt_cached_chunks requires --wrap --wrap_mode stream") if not args.owt_chunk_cache_dir: raise ValueError("--owt_cached_chunks requires --owt_chunk_cache_dir") if args.openwebtext_split not in {"train_minus_100k", "all"}: raise ValueError("--owt_cached_chunks is intended for OWT train/all splits") if ddp and not rank_zero(rank): ddp_barrier() dataset = CachedWrappedTextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, cache_dir=args.owt_chunk_cache_dir, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records=args.max_records, detokenizer=args.resolved_detokenizer, rebuild=args.owt_chunk_cache_rebuild, build_cache=rank_zero(rank), write_batch_size=args.owt_chunk_cache_write_batch, ) if ddp and rank_zero(rank): ddp_barrier() if args.owt_exact_repeat_per_chunk > 0: sampler = ExactRepeatDistributedSampler( len(dataset), args.owt_exact_repeat_per_chunk, 
seed=args.seed, rank=rank, world_size=world_size, ) shuffle_epoch_sampler = None else: sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) shuffle_epoch_sampler = sampler loader = DataLoader( dataset, batch_size=args.batch_size, sampler=sampler, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=False, ) dataset_size = f"owt_cached_chunks:{len(dataset)}" elif args.wrap and args.online_chunk_shuffle: if args.wrap_mode != "stream": raise ValueError("--online_chunk_shuffle currently requires --wrap_mode stream") dataset = ShuffledWrappedStreamingTextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records_per_epoch=args.max_records, detokenizer=args.resolved_detokenizer, chunk_buffer_size=args.online_chunk_shuffle_buffer, seed=args.seed, epoch=0, ) shuffle_epoch_sampler = dataset loader = DataLoader( dataset, batch_size=args.batch_size, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=False, ) dataset_size = f"wrapped_stream_online_shuffle:{args.online_chunk_shuffle_buffer}" elif args.wrap: dataset = WrappedStreamingTextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records_per_epoch=args.max_records, detokenizer=args.resolved_detokenizer, wrap_mode=args.wrap_mode, wrap_record_buffer_size=args.wrap_record_buffer_size, ) loader = DataLoader( dataset, batch_size=args.batch_size, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=False, ) dataset_size = 
f"wrapped_{args.wrap_mode}" elif args.streaming: dataset = StreamingTextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records_per_epoch=args.max_records, stride=args.stride, detokenizer=args.resolved_detokenizer, ) loader = DataLoader( dataset, batch_size=args.batch_size, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=False, ) dataset_size = "streaming" else: dataset = TextSequenceDataset( args.data_path, tokenizer, max_len=args.max_len, text_column=args.text_column, txt_record_mode=args.txt_record_mode, openwebtext_split=args.openwebtext_split, max_records=args.max_records, stride=args.stride, detokenizer=args.resolved_detokenizer, ) sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) loader = DataLoader( dataset, batch_size=args.batch_size, sampler=sampler, collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), num_workers=args.num_workers, **dataloader_perf_kwargs(args, device), drop_last=False, ) dataset_size = len(dataset) model = EndpointPredictor( vocab_size=model_vocab_size, max_len=args.max_len, d_model=args.d_model, n_heads=args.n_heads, n_layers=args.n_layers, dim_ff=args.dim_ff, dropout=args.dropout, model_type=args.model_type, cond_dim=args.cond_dim, input_format=args.state_format, output_bias=args.output_bias, norm_type=args.norm_type, ).to(device) if args.activation_checkpointing and hasattr(model, "set_activation_checkpointing"): model.set_activation_checkpointing( True, interval=args.activation_checkpoint_interval, scope=args.activation_checkpoint_scope, ) if args.torch_compile: model = torch.compile(model, mode=args.compile_mode if args.compile_mode != "default" else None) trainable_model = model if ddp: trainable_model = DistributedDataParallel( model, 
device_ids=[local_rank], output_device=local_rank, static_graph=bool(args.ddp_static_graph), gradient_as_bucket_view=bool(args.ddp_gradient_as_bucket_view), ) optimizer = configure_adamw_optimizer(unwrap_model(trainable_model), args, device) ema_state = init_ema_state(unwrap_model(trainable_model)) if args.ema_decay > 0 else None def lr_lambda(step: int) -> float: if args.warmup_steps > 0 and step < args.warmup_steps: return max(1e-12, float(step + 1) / float(args.warmup_steps)) if args.lr_schedule == "constant_warmup": return 1.0 progress = (step - args.warmup_steps) / max(1, args.total_steps - args.warmup_steps) min_lr_ratio = float(args.min_lr) / max(float(args.lr), 1e-12) return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * min(max(progress, 0.0), 1.0)))) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) start_step = 1 if args.init_model_path: if args.resume_path: raise ValueError("--init_model_path is mutually exclusive with --resume_path") ckpt = torch.load(args.init_model_path, map_location=device) init_state = dict(ckpt["model"]) model_state = unwrap_model(trainable_model).state_dict() if "pos_embed" in init_state and "pos_embed" in model_state and init_state["pos_embed"].shape != model_state["pos_embed"].shape: if args.init_pos_embed_mode == "strict": raise ValueError( f"init pos_embed shape {tuple(init_state['pos_embed'].shape)} != " f"model pos_embed shape {tuple(model_state['pos_embed'].shape)}; " "set --init_pos_embed_mode repeat or interpolate to adapt" ) src = init_state["pos_embed"] target_len = int(model_state["pos_embed"].shape[1]) if src.shape[1] > target_len: init_state["pos_embed"] = src[:, :target_len].contiguous() elif args.init_pos_embed_mode == "repeat": reps = math.ceil(target_len / int(src.shape[1])) init_state["pos_embed"] = src.repeat(1, reps, 1)[:, :target_len].contiguous() else: init_state["pos_embed"] = F.interpolate( src.transpose(1, 2), size=target_len, mode="linear", align_corners=True, ).transpose(1, 
2).contiguous() if rank_zero(rank): print( f"adapted_init_pos_embed from={tuple(src.shape)} to={tuple(init_state['pos_embed'].shape)} " f"mode={args.init_pos_embed_mode}", flush=True, ) for optional_bias_key in ("output_layer.linear.bias", "out_proj.bias"): if optional_bias_key in init_state and optional_bias_key not in model_state: init_state.pop(optional_bias_key) if rank_zero(rank): print(f"dropped_init_output_bias key={optional_bias_key} because current model has output_bias=False", flush=True) unwrap_model(trainable_model).load_state_dict(init_state) if ema_state is not None and "ema_model" in ckpt: for k, v in ckpt["ema_model"].items(): if k in ema_state: ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype)) if rank_zero(rank): print(f"initialized_model_from={args.init_model_path} start_step={start_step}", flush=True) if args.resume_path: ckpt = torch.load(args.resume_path, map_location=device) unwrap_model(trainable_model).load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) if "scheduler" in ckpt: scheduler.load_state_dict(ckpt["scheduler"]) else: scheduler.last_epoch = int(ckpt.get("step", 0)) if ema_state is not None and "ema_model" in ckpt: for k, v in ckpt["ema_model"].items(): if k in ema_state: ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype)) start_step = int(ckpt.get("step", 0)) + 1 if rank_zero(rank): print(f"resumed_from={args.resume_path} start_step={start_step}", flush=True) save_dir = Path(args.save_dir) if rank_zero(rank): save_dir.mkdir(parents=True, exist_ok=True) (save_dir / "args.json").write_text(json.dumps(vars(args), indent=2), encoding="utf-8") print(json.dumps({ "device": str(device), "rank": rank, "world_size": world_size, "samples": dataset_size, "vocab_size": model_vocab_size, "tokenizer_vocab_size": tokenizer.vocab_size, "save_dir": str(save_dir), "batch_size": args.batch_size, "grad_accum": args.grad_accum, "effective_batch_size": args.batch_size * 
args.grad_accum * world_size, "global_batch_size": args.global_batch_size, "lr_schedule": args.lr_schedule, "optimizer": args.optimizer, "warmup_steps": args.warmup_steps, "min_lr": args.min_lr, "weight_decay": args.weight_decay, "output_weight_decay": args.output_weight_decay, "adamw_param_groups": args.adamw_param_groups, "adam_beta1": args.adam_beta1, "adam_beta2": args.adam_beta2, "adam_eps": args.adam_eps, "muon_momentum": args.muon_momentum, "muon_ns_steps": args.muon_ns_steps, "muon_update_scale": args.muon_update_scale, "ema_decay": args.ema_decay, "ema_start_step": args.ema_start_step, "model_type": args.model_type, "output_bias": args.output_bias, "norm_type": args.norm_type, "t_sampling_mode": args.t_sampling_mode, "t_sampling_power": args.t_sampling_power, "t_sampling_eps": args.t_sampling_eps, "dual_t": args.dual_t, "corrupt_t_mode": args.corrupt_t_mode, "corrupt_min_t": args.corrupt_min_t, "corrupt_max_t": args.corrupt_max_t, "prefix_block_prob": args.prefix_block_prob, "prefix_block_len": args.prefix_block_len, "mask_ratio_floor_schedule": args.mask_ratio_floor_schedule, "dirichlet_endpoint_mode": args.dirichlet_endpoint_mode, "dirichlet_semantic_t_mode": args.dirichlet_semantic_t_mode, "dirichlet_semantic_t_value": args.dirichlet_semantic_t_value, "dirichlet_semantic_t_curve": args.dirichlet_semantic_t_curve, "dirichlet_semantic_t_power": args.dirichlet_semantic_t_power, "endpoint_sequence_random_prob_alpha": args.endpoint_sequence_random_prob_alpha, "categorical_wrong_from_full_vocab": args.categorical_wrong_from_full_vocab, "categorical_wrong_from_batch_valid_tokens": args.categorical_wrong_from_batch_valid_tokens, "categorical_wrong_basin_token_ids": args.categorical_wrong_basin_token_ids, "categorical_wrong_basin_prob": args.categorical_wrong_basin_prob, "categorical_wrong_unigram_prob": args.categorical_wrong_unigram_prob, "categorical_wrong_uniform_prob": args.categorical_wrong_uniform_prob, "categorical_wrong_corpus_unigram_path": 
args.categorical_wrong_corpus_unigram_path, "categorical_wrong_corpus_unigram_alpha": args.categorical_wrong_corpus_unigram_alpha, "categorical_wrong_basin_shared_prob": args.categorical_wrong_basin_shared_prob, "categorical_wrong_unigram_shared_prob": args.categorical_wrong_unigram_shared_prob, "mask_mixture_original_prob": args.mask_mixture_original_prob, "mask_mixture_lowk_prob": args.mask_mixture_lowk_prob, "mask_mixture_lowcorrupt_prob": args.mask_mixture_lowcorrupt_prob, "mask_mixture_block_prob": args.mask_mixture_block_prob, "mask_mixture_all_prob": args.mask_mixture_all_prob, "mask_mixture_lowk_clean_tokens": args.mask_mixture_lowk_clean_tokens, "mask_mixture_lowcorrupt_tokens": args.mask_mixture_lowcorrupt_tokens, "mask_mixture_block_tokens": args.mask_mixture_block_tokens, "simplex_bridge_sampler": args.simplex_bridge_sampler, "logistic_normal_sigma_min": args.logistic_normal_sigma_min, "logistic_normal_sigma_max": args.logistic_normal_sigma_max, "logistic_normal_tau_min": args.logistic_normal_tau_min, "logistic_normal_tau_max": args.logistic_normal_tau_max, "torch_compile": args.torch_compile, "compile_mode": args.compile_mode, "state_format": args.state_format, "target_loss": args.target_loss, "meanflow_weight": args.meanflow_weight, "rollout_train_prob": args.rollout_train_prob, "rollout_train_steps": args.rollout_train_steps, "rollout_train_infer_steps": args.rollout_train_infer_steps, "rollout_train_temp": args.rollout_train_temp, "rollout_train_max_gamma": args.rollout_train_max_gamma, "rollout_train_corrupt_only": args.rollout_train_corrupt_only, "rollout_train_samplewise": args.rollout_train_samplewise, "rollout_train_compute_always": args.rollout_train_compute_always, "bridge_noise_init": args.bridge_noise_init, "noise_sigma": args.noise_sigma, "allow_tf32": args.allow_tf32, "activation_checkpointing": args.activation_checkpointing, "activation_checkpoint_interval": args.activation_checkpoint_interval, "activation_checkpoint_scope": 
args.activation_checkpoint_scope, "ddp_static_graph": args.ddp_static_graph, "ddp_gradient_as_bucket_view": args.ddp_gradient_as_bucket_view, "blocking_data_transfer": args.blocking_data_transfer, "dataloader_prefetch_factor": args.dataloader_prefetch_factor, "full_train_stats": args.full_train_stats, "record_pad_truncate": args.record_pad_truncate, "record_add_eos": args.record_add_eos, "record_add_special_tokens": args.record_add_special_tokens, "record_pad_token": args.record_pad_token, "record_shuffle_buffer": args.record_shuffle_buffer, "wrap": args.wrap, "wrap_mode": args.wrap_mode, "wrap_record_buffer_size": args.wrap_record_buffer_size, "owt_cached_chunks": args.owt_cached_chunks, "owt_chunk_cache_dir": args.owt_chunk_cache_dir, "owt_chunk_cache_rebuild": args.owt_chunk_cache_rebuild, "owt_chunk_cache_write_batch": args.owt_chunk_cache_write_batch, "owt_exact_repeat_per_chunk": args.owt_exact_repeat_per_chunk, "online_chunk_shuffle": args.online_chunk_shuffle, "online_chunk_shuffle_buffer": args.online_chunk_shuffle_buffer, "openwebtext_split": args.openwebtext_split, "detokenizer": args.detokenizer, "resolved_detokenizer": args.resolved_detokenizer, "num_workers": args.num_workers, "latest_every": args.latest_every, "resume_path": args.resume_path, }, indent=2), flush=True) use_amp = args.bf16 and device.type == "cuda" data_epoch = 0 data_iter = iter(loader) running = [] start = time.time() trainable_model.train() optimizer.zero_grad(set_to_none=True) for step in range(start_step, args.total_steps + 1): last_batch = None for micro_step in range(args.grad_accum): try: batch = next(data_iter) except StopIteration: data_epoch += 1 if shuffle_epoch_sampler is not None and hasattr(shuffle_epoch_sampler, "set_epoch"): shuffle_epoch_sampler.set_epoch(data_epoch) data_iter = iter(loader) batch = next(data_iter) last_batch = batch non_blocking = device.type == "cuda" and not args.blocking_data_transfer ids = batch["ids"].to(device, non_blocking=non_blocking) 
attn_mask = batch["attn_mask"].to(device, non_blocking=non_blocking) loss_mask = batch.get("loss_mask") cond_seq_mask = batch.get("cond_seq_mask") if loss_mask is not None: loss_mask = loss_mask.to(device, non_blocking=non_blocking) if cond_seq_mask is not None: cond_seq_mask = cond_seq_mask.to(device, non_blocking=non_blocking) bridge_input_mask = loss_mask if loss_mask is not None else attn_mask bridge_attn_mask, force_corrupt_mask = maybe_make_prefix_block_masks(args, bridge_input_mask) model_t = sample_time( ids.size(0), args.min_t, args.max_t, device, mode=args.t_sampling_mode, power=args.t_sampling_power, eps=args.t_sampling_eps, ) bridge = make_bridge( args, ids, bridge_attn_mask, model_vocab_size, force_t=resolve_bridge_force_t(args, model_t), force_corrupt_mask=force_corrupt_mask, categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs, ) if loss_mask is not None: model_attn_mask = attn_mask if cond_seq_mask is not None and float(args.label_drop_prob) > 0.0: drop = torch.rand(ids.size(0), device=device) < float(args.label_drop_prob) model_attn_mask = attn_mask & ~(cond_seq_mask & drop.view(-1, 1)) bridge.attn_mask = model_attn_mask loss_t_weight = sample_loss_t_weight(args, model_t) sync_grad = micro_step == args.grad_accum - 1 sync_context = ( trainable_model.no_sync() if ddp and isinstance(trainable_model, DistributedDataParallel) and not args.ddp_static_graph and not sync_grad else nullcontext() ) with sync_context: grad_enabled_before_rollout = torch.is_grad_enabled() loss_state, rollout_train_applied, rollout_train_sample_mask = maybe_make_rollout_train_state( args, trainable_model, bridge.state, model_t, bridge.attn_mask, bridge.corrupt_mask, ) grad_enabled_after_rollout = torch.is_grad_enabled() with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16): logits = trainable_model(loss_state, model_t, bridge.attn_mask) logits_requires_grad = bool(logits.requires_grad) loss_logits = logits if 
args.prior_center_loss_beta > 0: prior_state = make_prior_center_state( args, ids.size(0), ids.size(1), model_vocab_size, device, ) with torch.no_grad(): prior_logits = trainable_model(prior_state, model_t, bridge.attn_mask) loss_logits = logits.float() - float(args.prior_center_loss_beta) * prior_logits.float() if args.target_loss == "soft_ce": if args.loss_t_weight_mode == "none": raw_loss = masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask) else: raw_loss = weighted_masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask, loss_t_weight) else: if args.loss_t_weight_mode == "none": raw_loss = masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask) else: raw_loss = weighted_masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask, loss_t_weight) raw_loss_requires_grad = bool(raw_loss.requires_grad) if args.meanflow_weight > 0: if args.state_format == "prob": current_probs = loss_state else: current_probs = loss_state.float().exp() meanflow_loss = masked_meanflow_l2( logits, current_probs=current_probs, target_probs=bridge.target_probs, mask=bridge.corrupt_mask, ) total_loss = raw_loss + float(args.meanflow_weight) * meanflow_loss else: meanflow_loss = raw_loss.new_zeros(()) total_loss = raw_loss if float(args.rollout_train_prob) > 0.0: total_loss = total_loss + zero_param_anchor(unwrap_model(trainable_model), total_loss) loss = total_loss / max(1, args.grad_accum) loss.backward() last_micro_step = micro_step == args.grad_accum - 1 heavy_stats = args.full_train_stats or (step % args.log_every == 0 and last_micro_step) collect_stats = rank_zero(rank) and ( args.full_train_stats or args.log_train_stats_each_step or heavy_stats ) if collect_stats: # Keep window logging cheap: aggregate accuracy/count diagnostics # from every micro-batch, while expensive extras stay log-step only. 
def scalar(x: torch.Tensor) -> float: return float(x.detach().float().cpu()) batch_weight = float(ids.size(0)) corrupt_count = bridge.corrupt_mask.float().sum() stats = { "loss": float(total_loss.detach().cpu()), "loss__weight": scalar(corrupt_count), "loss_recon": float(raw_loss.detach().cpu()), "loss_recon__weight": scalar(corrupt_count), "loss_meanflow": float(meanflow_loss.detach().cpu()), "loss_meanflow__weight": scalar(corrupt_count), "mean_model_t": float(model_t.mean().detach().cpu()), "mean_model_t__weight": batch_weight, "mean_corrupt_t": float(bridge.t.mean().detach().cpu()), "mean_corrupt_t__weight": batch_weight, "mean_loss_t_weight": float(loss_t_weight.mean().detach().cpu()), "mean_loss_t_weight__weight": batch_weight, "prior_center_loss_beta": float(args.prior_center_loss_beta), "rollout_train_applied": float(rollout_train_applied), "rollout_train_applied__weight": batch_weight, "grad_enabled_before_rollout": float(grad_enabled_before_rollout), "grad_enabled_after_rollout": float(grad_enabled_after_rollout), "logits_requires_grad": float(logits_requires_grad), "raw_loss_requires_grad": float(raw_loss_requires_grad), } with torch.no_grad(): pred = loss_logits.detach().argmax(dim=-1) attn_mask_f = bridge.attn_mask.float() corrupt_mask_f = bridge.corrupt_mask.float() attn_count = attn_mask_f.sum() correct_all = ((pred == bridge.ids) & bridge.attn_mask).float().sum() correct_corrupt = ((pred == bridge.ids) & bridge.corrupt_mask).float().sum() if scalar(attn_count) > 0.0: stats.update({ "acc_all": scalar(correct_all / attn_count.clamp_min(1.0)), "acc_all__weight": scalar(attn_count), "corrupt_frac": scalar(corrupt_count / attn_count.clamp_min(1.0)), "corrupt_frac__weight": scalar(attn_count), }) if scalar(corrupt_count) > 0.0: stats.update({ "acc_corrupt": scalar(correct_corrupt / corrupt_count.clamp_min(1.0)), "acc_corrupt__weight": scalar(corrupt_count), "loss_corrupt": float(raw_loss.detach().cpu()), "loss_corrupt__weight": scalar(corrupt_count), 
"wrong_frac": scalar(bridge.wrong_mask.float().sum() / corrupt_count.clamp_min(1.0)), "wrong_frac__weight": scalar(corrupt_count), }) init_pred = loss_state.detach().argmax(dim=-1) init_correct = ((init_pred == bridge.ids) & bridge.corrupt_mask).float().sum() stats.update({ "init_acc_corrupt": scalar(init_correct / corrupt_count.clamp_min(1.0)), "init_acc_corrupt__weight": scalar(corrupt_count), }) if args.log_t_bin_stats: total_corrupt = corrupt_count.clamp_min(1.0) sample_t = bridge.t.detach().float().view(-1) for lo, hi in zip(t_bin_edges[:-1], t_bin_edges[1:]): if hi >= t_bin_edges[-1]: sample_mask = (sample_t >= lo) & (sample_t <= hi) else: sample_mask = (sample_t >= lo) & (sample_t < hi) pos_mask = bridge.corrupt_mask & sample_mask[:, None] denom = pos_mask.float().sum() if scalar(denom) <= 0.0: continue tag = f"{lo:.1f}_{min(hi, 1.0):.1f}".replace(".", "p") correct_bin = ((pred == bridge.ids) & pos_mask).float().sum() stats[f"acc_corrupt_t_{tag}"] = scalar(correct_bin / denom.clamp_min(1.0)) stats[f"acc_corrupt_t_{tag}__weight"] = scalar(denom) stats[f"corrupt_frac_t_{tag}"] = scalar(denom / total_corrupt) stats[f"corrupt_frac_t_{tag}__weight"] = scalar(corrupt_count) model_for_stats = unwrap_model(trainable_model) out_linear = getattr(getattr(model_for_stats, "output_layer", None), "linear", None) if out_linear is None: out_linear = getattr(model_for_stats, "out_proj", None) if out_linear is not None: out_w = out_linear.weight.detach().float() out_g = out_linear.weight.grad stats["out_w_norm"] = float(out_w.norm().cpu()) stats["out_g_norm"] = float(out_g.detach().float().norm().cpu()) if out_g is not None else 0.0 if heavy_stats: stats.update(summarize_batch( loss_logits.detach(), bridge.ids, bridge.attn_mask, bridge.corrupt_mask, target_probs=bridge.target_probs if args.target_loss == "soft_ce" else None, include_loss=True, )) if args.log_t_bin_stats: stats.update(summarize_t_bin_acc( loss_logits.detach(), bridge.ids, bridge.corrupt_mask, 
bridge.t.detach(), edges=t_bin_edges, )) stats.update({ "init_gold_top10": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 10).cpu()), "init_gold_top100": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 100).cpu()), }) if rollout_train_sample_mask is not None: applied_pos = bridge.corrupt_mask & rollout_train_sample_mask.view(-1, 1) kept_pos = bridge.corrupt_mask & (~rollout_train_sample_mask.view(-1, 1)) stats.update({ "rollout_applied_pos_frac": float((applied_pos.float().sum() / bridge.corrupt_mask.float().sum().clamp_min(1.0)).detach().cpu()), "init_acc_rollout_applied": float(masked_acc(loss_state.detach(), bridge.ids, applied_pos).cpu()), "init_acc_rollout_kept": float(masked_acc(loss_state.detach(), bridge.ids, kept_pos).cpu()), "logit_acc_rollout_applied": float(masked_acc(loss_logits.detach(), bridge.ids, applied_pos).cpu()), "logit_acc_rollout_kept": float(masked_acc(loss_logits.detach(), bridge.ids, kept_pos).cpu()), }) running.append(stats) if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() scheduler.step() if ema_state is not None and step >= args.ema_start_step: update_ema_state(ema_state, unwrap_model(trainable_model), args.ema_decay) optimizer.zero_grad(set_to_none=True) if step % args.log_every == 0: if rank_zero(rank): avg = average_stats_window(running) if running else {} elapsed = time.time() - start print(" ".join([ f"step={step}", f"micro_steps={step * args.grad_accum}", f"elapsed={elapsed:.1f}s", f"lr={scheduler.get_last_lr()[0]:.6e}", ] + [f"{k}={v:.4f}" for k, v in avg.items()]), flush=True) running = [] start = time.time() if rank_zero(rank) and args.eval_every > 0 and step % args.eval_every == 0: run_demo(args, unwrap_model(trainable_model), tokenizer, last_batch, device) if rank_zero(rank) and args.save_every > 0 and step % args.save_every == 0: save_checkpoint(save_dir / f"step_{step:07d}.pt", unwrap_model(trainable_model), 
optimizer, args, tokenizer, step, scheduler, ema_state) save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state) elif rank_zero(rank) and args.latest_every > 0 and step % args.latest_every == 0: save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state) if rank_zero(rank): save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, args.total_steps, scheduler, ema_state) if ddp: dist.destroy_process_group() if __name__ == "__main__": main()