Hon-Wong's picture
Add files using upload-large-folder tool
dd7a3fd verified
Raw
History Blame Contribute Delete
95.1 kB
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import inspect
import json
import math
import os
import random
import time
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler
from flowtext_lab.bridges import make_dirichlet_bridge_batch, make_prob_bridge_batch
try:
from flowtext_lab.bridges import make_gaussian_bridge_batch
except ImportError: # Optional legacy bridge; most current runs use dirichlet.
make_gaussian_bridge_batch = None
from flowtext_lab.data import (
CachedWrappedTextSequenceDataset,
EosPadCollator,
FixedPadCollator,
RecordPadTruncateTextSequenceDataset,
ShuffledWrappedStreamingTextSequenceDataset,
StreamingTextSequenceDataset,
TextSequenceDataset,
WrappedStreamingTextSequenceDataset,
)
from flowtext_lab.decode import fill_blank_init, soft_residual_decode
from flowtext_lab.metrics import (
masked_acc,
masked_ce,
masked_meanflow_l2,
masked_soft_ce,
masked_topk_acc,
summarize_batch,
summarize_t_bin_acc,
)
from flowtext_lab.model import EndpointPredictor
from flowtext_lab.text_detokenization import infer_detokenizer_name
from flowtext_lab.tokenization import BpeTextTokenizer
def parse_args() -> argparse.Namespace:
    """Build the training command-line interface and parse the process arguments.

    Only ``--data_path``, ``--tokenizer_path`` and ``--save_dir`` are required;
    every other option has a default. Flags declared with
    ``argparse.BooleanOptionalAction`` accept both ``--flag`` and ``--no-flag``.
    Registration order below is also the order ``--help`` lists the options in.

    Returns:
        The parsed ``argparse.Namespace`` (read from ``sys.argv``).
    """
    p = argparse.ArgumentParser()
    # --- Data source, tokenizer, and output location ---
    p.add_argument("--data_path", required=True)
    p.add_argument("--text_column", default=None)
    p.add_argument("--txt_record_mode", choices=["auto", "line", "eot"], default="auto")
    p.add_argument("--detokenizer", default="auto")
    p.add_argument("--openwebtext_split", choices=["all", "train_minus_100k", "valid_last_100k"], default="all")
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--save_dir", required=True)
    p.add_argument("--max_records", type=int, default=0)
    # --- ELF-style conditional (condition -> target) datasets ---
    p.add_argument(
        "--elf_conditional_hf",
        action="store_true",
        help="Load ELF-style seq2seq Arrow datasets with condition_input_ids/input_ids columns.",
    )
    p.add_argument("--eval_data_path", default="")
    p.add_argument("--dataset_cache_dir", default="")
    p.add_argument("--max_input_len", type=int, default=64)
    p.add_argument("--conditional_pad_token", choices=["pad", "eos"], default="eos")
    p.add_argument("--label_drop_prob", type=float, default=0.0)
    # --- Input pipeline: streaming, per-record pad/truncate, wrapped packing ---
    p.add_argument("--streaming", action="store_true")
    p.add_argument(
        "--record_pad_truncate",
        action="store_true",
        help=(
            "ELF-style OWT input pipeline: one record -> one example, no global "
            "stream packing, truncate to max_len and fixed-pad in the collator."
        ),
    )
    p.add_argument(
        "--record_add_eos",
        action="store_true",
        help="Append tokenizer EOS/SEP to each record in --record_pad_truncate mode.",
    )
    p.add_argument(
        "--record_add_special_tokens",
        action="store_true",
        help="Let the tokenizer add model special tokens in --record_pad_truncate mode.",
    )
    p.add_argument(
        "--record_pad_token",
        choices=["pad", "eos"],
        default="pad",
        help="Pad id for --record_pad_truncate. ELF-style uses tokenizer PAD when available.",
    )
    p.add_argument(
        "--record_shuffle_buffer",
        type=int,
        default=10000,
        help="Bounded online shuffle buffer for --record_pad_truncate; 0 disables it.",
    )
    p.add_argument("--wrap", action="store_true", help="Use wrapped packing: [BOS] + payload + [EOS].")
    p.add_argument(
        "--wrap_mode",
        choices=["stream", "record"],
        default="stream",
        help="stream keeps the original fixed-width token stream; record packs whole records and only truncates overlength records.",
    )
    p.add_argument("--wrap_record_buffer_size", type=int, default=200)
    # --- OpenWebText cached-chunk pool and online chunk shuffling ---
    p.add_argument(
        "--owt_cached_chunks",
        action="store_true",
        help="OWT-only FLM/Duo-style cached chunk pool with epoch-level sampler shuffle.",
    )
    p.add_argument(
        "--owt_chunk_cache_dir",
        default="",
        help="Cache directory for --owt_cached_chunks. Stores meta.json and chunks.i32.bin.",
    )
    p.add_argument("--owt_chunk_cache_rebuild", action="store_true")
    p.add_argument("--owt_chunk_cache_write_batch", type=int, default=4096)
    p.add_argument(
        "--owt_exact_repeat_per_chunk",
        type=int,
        default=0,
        help=(
            "Pilot-only cached OWT sampler: repeat each cached chunk exactly this "
            "many times, shuffle once deterministically, then shard across ranks."
        ),
    )
    p.add_argument(
        "--online_chunk_shuffle",
        action="store_true",
        help="Use an online bounded shuffle buffer over wrapped stream chunks.",
    )
    p.add_argument("--online_chunk_shuffle_buffer", type=int, default=10000)
    # --- Sequence length, batching, and data loading ---
    p.add_argument("--max_len", type=int, default=128)
    p.add_argument("--stride", type=int, default=0)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    p.add_argument("--blocking_data_transfer", action="store_true")
    p.add_argument("--global_batch_size", type=int, default=0, help="If >0, overrides grad_accum with ceil(global_batch_size / batch_size).")
    p.add_argument("--grad_accum", type=int, default=1)
    # --- Optimization: steps, learning rate, optimizer, EMA, schedule ---
    p.add_argument("--total_steps", type=int, default=1000)
    p.add_argument("--lr", type=float, default=6e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument(
        "--output_weight_decay",
        type=float,
        default=-1.0,
        help="If >=0, use this AdamW weight decay only for the output projection weight.",
    )
    p.add_argument("--adam_beta1", type=float, default=0.9)
    p.add_argument("--adam_beta2", type=float, default=0.95)
    p.add_argument("--adam_eps", type=float, default=1e-8)
    p.add_argument("--optimizer", choices=["adamw", "muon"], default="adamw")
    p.add_argument("--muon_momentum", type=float, default=0.95)
    p.add_argument("--muon_ns_steps", type=int, default=5)
    p.add_argument("--muon_update_scale", type=float, default=1.0)
    p.add_argument("--ema_decay", type=float, default=0.0)
    p.add_argument("--ema_start_step", type=int, default=0)
    p.add_argument("--warmup_steps", type=int, default=2000)
    p.add_argument("--lr_schedule", choices=["constant_warmup", "cosine"], default="cosine")
    p.add_argument("--min_lr", type=float, default=6e-5)
    p.add_argument(
        "--adamw_param_groups",
        choices=["nanogpt", "all_decay"],
        default="nanogpt",
        help="nanoGPT decays only 2D matmul/embedding tensors; all_decay applies weight decay to every parameter.",
    )
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--seed", type=int, default=42)
    # --- Model architecture ---
    p.add_argument("--d_model", type=int, default=384)
    p.add_argument("--cond_dim", type=int, default=128)
    p.add_argument("--n_layers", type=int, default=6)
    p.add_argument("--n_heads", type=int, default=6)
    p.add_argument("--dim_ff", type=int, default=1536)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument(
        "--output_bias",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to include a vocabulary bias in the final output projection. Default false to avoid global-frequency basin amplification.",
    )
    p.add_argument(
        "--norm_type",
        choices=["rmsnorm", "layernorm"],
        default="rmsnorm",
        help="Normalization used by DDiT blocks/final head. Default rmsnorm follows modern ELF/T5-style practice.",
    )
    p.add_argument(
        "--vocab_size_override",
        type=int,
        default=0,
        help="Debug-only: train with a compact/remapped vocabulary of this size instead of tokenizer.vocab_size.",
    )
    p.add_argument("--model_type", choices=["transformer", "ddit"], default="transformer")
    # --- State representation, bridge, and training objective ---
    p.add_argument("--state_format", choices=["logprob", "prob"], default="logprob")
    p.add_argument("--bridge", choices=["prob", "dirichlet", "gaussian"], default="prob")
    p.add_argument("--target_loss", choices=["hard_ce", "soft_ce"], default="soft_ce")
    p.add_argument("--meanflow_weight", type=float, default=0.0)
    p.add_argument(
        "--loss_t_weight_mode",
        choices=["none", "linear_t", "quadratic_t", "one_minus_t_floor", "drop_low_t"],
        default="none",
        help="Per-sample t loss weighting. one_minus_t_floor downweights easy high-t targets.",
    )
    p.add_argument("--loss_t_min_weight", type=float, default=0.0)
    p.add_argument("--loss_t_drop_below", type=float, default=0.2)
    p.add_argument(
        "--prior_center_loss_beta",
        type=float,
        default=0.0,
        help="If >0, train CE on logits - beta * logits(prior_state, t), reducing unconditional prior attraction.",
    )
    p.add_argument("--prior_center_state", choices=["uniform"], default="uniform")
    # --- Detached self-rollout training (maybe_make_rollout_train_state) ---
    p.add_argument(
        "--rollout_train_prob",
        type=float,
        default=0.0,
        help=(
            "Probability of replacing the supervised input state with a detached "
            "self-rollout state before computing the reconstruction loss."
        ),
    )
    p.add_argument(
        "--rollout_train_steps",
        type=int,
        default=1,
        help="Number of no-grad self-rollout updates before the supervised loss.",
    )
    p.add_argument(
        "--rollout_train_infer_steps",
        type=int,
        default=64,
        help="Decode horizon used to set the one-step flowmap gamma during rollout training.",
    )
    p.add_argument(
        "--rollout_train_temp",
        type=float,
        default=1.45,
        help="Endpoint temperature used in the no-grad rollout-training update.",
    )
    p.add_argument(
        "--rollout_train_max_gamma",
        type=float,
        default=1.0,
        help="Clamp for the no-grad rollout-training flowmap gamma. Set <=0 to disable the clamp.",
    )
    p.add_argument(
        "--rollout_train_corrupt_only",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Only replace corrupted positions with the self-rollout state, preserving clean anchors.",
    )
    p.add_argument(
        "--rollout_train_samplewise",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Apply rollout stochastically per sample instead of per micro-batch. "
            "This always computes the rollout forward when prob>0, which keeps "
            "Tensor Core work steady while preserving the requested rollout mix."
        ),
    )
    p.add_argument(
        "--rollout_train_compute_always",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "For batchwise rollout, always compute the detached rollout forward "
            "even when the batch coin decides not to apply it. This preserves "
            "the old rollout objective while increasing steady Tensor Core work."
        ),
    )
    p.add_argument(
        "--rollout_train_debug_return_base",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Debug only: still run the detached rollout forward, but return the "
            "original bridge state to isolate rollout-forward side effects from "
            "the rollout-state training distribution."
        ),
    )
    # --- Time sampling (model/flow t and corruption t) ---
    p.add_argument("--target_prob", type=float, default=0.99)
    p.add_argument("--min_t", type=float, default=0.0)
    p.add_argument("--max_t", type=float, default=1.0)
    p.add_argument(
        "--t_sampling_mode",
        choices=["uniform", "power_low", "power_high", "logit_uniform"],
        default="uniform",
        help=(
            "How to sample model/flow time t. power_low uses u**gamma to "
            "oversample low t; power_high uses 1-(1-u)**gamma to oversample high t; "
            "logit_uniform samples uniformly in logit(t)."
        ),
    )
    p.add_argument("--t_sampling_power", type=float, default=1.0)
    p.add_argument("--t_sampling_eps", type=float, default=1e-4)
    p.add_argument("--dual_t", action="store_true", help="Use separate model/flow time and corruption/support time.")
    p.add_argument("--corrupt_t_mode", choices=["same", "independent", "constant"], default="independent")
    p.add_argument("--corrupt_t_value", type=float, default=0.0)
    # None here means "not set on the command line".
    p.add_argument("--corrupt_min_t", type=float, default=None)
    p.add_argument("--corrupt_max_t", type=float, default=None)
    # --- Masking / corruption patterns ---
    p.add_argument(
        "--prefix_block_prob",
        type=float,
        default=0.0,
        help="Train on grow-context states: expose only prefix plus one active block, corrupting the active block and masking future tokens.",
    )
    p.add_argument("--prefix_block_len", type=int, default=128)
    p.add_argument("--min_mask_ratio", type=float, default=0.1)
    p.add_argument("--max_mask_ratio", type=float, default=1.0)
    p.add_argument(
        "--mask_ratio_floor_schedule",
        choices=["none", "one_minus_t"],
        default="none",
        help=(
            "Optional per-sample lower-bound schedule for mask ratio. "
            "one_minus_t uses max(min_mask_ratio, 1-t), so t=0 is all-mask "
            "while high-t reverts to the configured min_mask_ratio."
        ),
    )
    p.add_argument(
        "--mask_mixture_original_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of using the original uniform mask-ratio sampler.",
    )
    p.add_argument(
        "--mask_mixture_lowk_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of keeping only K clean tokens and corrupting the rest.",
    )
    p.add_argument(
        "--mask_mixture_lowcorrupt_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting only K valid tokens and keeping the rest anchored.",
    )
    p.add_argument(
        "--mask_mixture_block_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting one contiguous valid-token block and keeping the rest anchored.",
    )
    p.add_argument(
        "--mask_mixture_all_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting every valid token.",
    )
    # The three count lists below are kept as raw comma-separated strings here.
    p.add_argument(
        "--mask_mixture_lowk_clean_tokens",
        default="1,2,4,8,16,32,64",
        help="Comma-separated clean-token counts sampled uniformly for the low-K mixture branch.",
    )
    p.add_argument(
        "--mask_mixture_lowcorrupt_tokens",
        type=str,
        default="1,2,4,8,16,32,64",
        help="Comma-separated corrupt-token counts sampled uniformly for the low-corrupt-K mixture branch.",
    )
    p.add_argument(
        "--mask_mixture_block_tokens",
        type=str,
        default="64,128",
        help="Comma-separated contiguous block lengths sampled uniformly for the block-corrupt mixture branch.",
    )
    p.add_argument(
        "--clean_state_mode",
        choices=["onehot", "bridge"],
        default="onehot",
        help="State for non-corrupted support tokens: exact one-hot gold, or a same-t bridge sample around gold.",
    )
    # NOTE(review): no type= here, so this stays a string; presumably parsed
    # downstream (it may hold more than one value) — confirm against the caller.
    p.add_argument("--wrong_token_replace_prob", default="0.0")
    p.add_argument("--wrong_token_schedule", choices=["constant", "hard", "linear_t", "exp_t", "exp_k1"], default="constant")
    p.add_argument("--wrong_token_exp_k", type=float, default=1.0)
    # --- Dirichlet / categorical endpoint bridge configuration ---
    p.add_argument("--dirichlet_concentration_min", type=float, default=1.0)
    p.add_argument("--dirichlet_concentration_max", type=float, default=1024.0)
    p.add_argument(
        "--dirichlet_endpoint_mode",
        choices=["bernoulli_wrong", "dual_t_line", "categorical_dual_t"],
        default="bernoulli_wrong",
    )
    p.add_argument("--dirichlet_semantic_t_mode", choices=["same", "independent", "constant"], default="same")
    p.add_argument("--dirichlet_semantic_t_value", type=float, default=0.0)
    p.add_argument(
        "--dirichlet_semantic_t_curve",
        choices=["linear", "logit_power"],
        default="linear",
        help=(
            "Transform semantic t before categorical_dual_t Bernoulli gold sampling. "
            "logit_power uses p=t^gamma/(t^gamma+(1-t)^gamma)."
        ),
    )
    p.add_argument(
        "--dirichlet_semantic_t_power",
        type=float,
        default=1.0,
        help="Gamma for --dirichlet_semantic_t_curve logit_power.",
    )
    p.add_argument(
        "--endpoint_sequence_random_prob_alpha",
        type=float,
        default=0.0,
        help=(
            "For categorical_dual_t: with probability alpha*(1-t), force all "
            "corrupted positions in a sequence to use the random endpoint branch. "
            "0 disables the sequence-level gate."
        ),
    )
    p.add_argument(
        "--categorical_wrong_from_full_vocab",
        action="store_true",
        help="For categorical_dual_t only: sample the non-gold endpoint from the full vocab, allowing it to equal gold.",
    )
    p.add_argument(
        "--categorical_wrong_from_batch_valid_tokens",
        action="store_true",
        help="For categorical_dual_t only: sample the non-gold endpoint from valid tokens in the current batch.",
    )
    p.add_argument(
        "--categorical_wrong_basin_token_ids",
        default="",
        help=(
            "For categorical_dual_t only: comma-separated token ids for a basin-bank wrong endpoint branch. "
            "Used when any categorical_wrong_*_prob is >0."
        ),
    )
    p.add_argument("--categorical_wrong_basin_prob", type=float, default=0.0)
    p.add_argument("--categorical_wrong_unigram_prob", type=float, default=0.0)
    p.add_argument("--categorical_wrong_uniform_prob", type=float, default=0.0)
    p.add_argument(
        "--categorical_wrong_corpus_unigram_path",
        default="",
        help=(
            "Optional torch/numpy file containing OWT corpus unigram counts or probabilities. "
            "When set, categorical_wrong_unigram_prob samples from this corpus distribution."
        ),
    )
    p.add_argument(
        "--categorical_wrong_corpus_unigram_alpha",
        type=float,
        default=1.0,
        help="Exponent applied to corpus unigram counts before normalization; values <1 flatten high-frequency tokens.",
    )
    p.add_argument(
        "--categorical_wrong_basin_shared_prob",
        type=float,
        default=0.0,
        help="Per-sequence probability of sharing one basin-bank wrong token across all corrupted positions.",
    )
    p.add_argument(
        "--categorical_wrong_unigram_shared_prob",
        type=float,
        default=0.0,
        help="Per-sequence probability of sharing one corpus/batch-unigram wrong token across all corrupted positions.",
    )
    p.add_argument(
        "--simplex_bridge_sampler",
        choices=["dirichlet", "logistic_normal_linear_mean"],
        default="dirichlet",
        help="How to sample simplex bridge states after choosing the endpoint.",
    )
    p.add_argument("--logistic_normal_sigma_min", type=float, default=0.18)
    p.add_argument("--logistic_normal_sigma_max", type=float, default=2.2)
    p.add_argument("--logistic_normal_tau_min", type=float, default=0.65)
    p.add_argument("--logistic_normal_tau_max", type=float, default=1.15)
    p.add_argument("--dirichlet_scale", type=float, default=128.0)
    p.add_argument("--dirichlet_floor", type=float, default=1e-3)
    # Numerical floor used by state_to_probs / probs_to_state and related clamps.
    p.add_argument("--eps", type=float, default=1e-8)
    # --- Logging, evaluation, and checkpointing cadence ---
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--eval_every", type=int, default=100)
    p.add_argument("--save_every", type=int, default=500)
    p.add_argument("--latest_every", type=int, default=0, help="If >0, also refresh save_dir/latest.pt at this interval.")
    p.add_argument("--resume_path", default="", help="Optional checkpoint path to resume model/optimizer/scheduler state.")
    p.add_argument(
        "--init_model_path",
        default="",
        help="Optional checkpoint path used only to initialize model weights; optimizer/scheduler/step start fresh.",
    )
    p.add_argument(
        "--init_pos_embed_mode",
        choices=["strict", "repeat", "interpolate"],
        default="strict",
        help="When --init_model_path has a different pos_embed length, adapt it to the current max_len.",
    )
    # --- Decoding / demo generation ---
    p.add_argument("--infer_steps", type=int, default=64)
    p.add_argument("--decode_damping", type=float, default=1.0)
    p.add_argument("--max_gamma", type=float, default=1.0)
    p.add_argument("--decode_solver", choices=["simple", "flowmap"], default="flowmap")
    p.add_argument("--noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal")
    p.add_argument("--bridge_noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal")
    p.add_argument("--noise_sigma", type=float, default=-1.0)
    p.add_argument("--demo_mask_ratios", default="0.1,0.5,1.0")
    p.add_argument("--demo_fill_t", type=float, default=0.0)
    # --- Runtime: precision, activation checkpointing, DDP ---
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--allow_tf32", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--activation_checkpointing", action="store_true")
    p.add_argument("--activation_checkpoint_interval", type=int, default=1)
    p.add_argument("--activation_checkpoint_scope", choices=["block", "mlp"], default="block")
    p.add_argument("--ddp_static_graph", action="store_true")
    p.add_argument("--ddp_gradient_as_bucket_view", action=argparse.BooleanOptionalAction, default=True)
    # --- Train-time diagnostics ---
    p.add_argument(
        "--full_train_stats",
        action="store_true",
        help="Compute heavy train diagnostics every micro-step. Default computes cheap stats every optimizer step and heavy loss diagnostics on log steps.",
    )
    p.add_argument(
        "--log_train_stats_each_step",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Accumulate lightweight train diagnostics from every micro-batch, then report weighted averages over each log window.",
    )
    p.add_argument(
        "--log_t_bin_stats",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Log corrupt-token accuracy bucketed by model t, making mixed-t train logs easier to read.",
    )
    p.add_argument(
        "--log_t_bin_edges",
        default="0,0.2,0.4,0.6,0.8,1.0",
        help="Comma-separated t-bin edges used when --log_t_bin_stats is enabled.",
    )
    # --- torch.compile ---
    p.add_argument("--torch_compile", action="store_true")
    p.add_argument("--compile_mode", choices=["default", "reduce-overhead", "max-autotune"], default="max-autotune")
    return p.parse_args()
def _datasetdict_to_single_split(ds):
    """Return one split from a split mapping, or ``ds`` unchanged otherwise.

    For a split mapping, the preferred order is ``train``, then ``validation``,
    then ``test``. If none of those exist, the first split in iteration order
    is returned.

    A value counts as a split mapping if it is a ``dict`` instance, or if it has
    ``keys`` but no ``column_names``. The ``dict`` check is needed because
    Hugging Face ``DatasetDict`` subclasses ``dict`` and *also* exposes a
    ``column_names`` attribute (per-split column lists). The duck-typed test
    alone would therefore pass a real ``DatasetDict`` through unchanged.

    Raises:
        ValueError: if ``ds`` is a split mapping with no splits.
    """
    is_split_mapping = isinstance(ds, dict) or (hasattr(ds, "keys") and not hasattr(ds, "column_names"))
    if not is_split_mapping:
        return ds
    for key in ("train", "validation", "test"):
        if key in ds:
            return ds[key]
    # Fall back to the first split; an empty mapping gets a clear error instead
    # of a bare StopIteration from next(iter(...)).
    for key in ds.keys():
        return ds[key]
    raise ValueError("Cannot select a split from an empty dataset mapping")
def _load_arrow_dir_without_hf_metadata(path: str):
    """Load ``data-*.arrow`` shards under ``path`` while ignoring HF schema metadata.

    This is a fallback for ELF datasets saved with newer ``List`` feature
    metadata. Some nodes have an older ``datasets`` build that cannot parse the
    ``List`` feature type, even though pyarrow reads the Arrow files fine.
    Each shard is read as an Arrow IPC stream, its schema metadata is stripped,
    and the shards are concatenated. ``datasets`` then infers equivalent
    features from the plain Arrow schema.

    Raises:
        FileNotFoundError: if no ``data-*.arrow`` shard exists directly under ``path``.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc
    from datasets import Dataset

    shard_paths = sorted(Path(path).glob("data-*.arrow"))
    if len(shard_paths) == 0:
        raise FileNotFoundError(f"No data-*.arrow files found under {path}")

    def _read_shard(shard_path):
        # Read and cast while the memory map is still open.
        with pa.memory_map(str(shard_path), "r") as source:
            shard = ipc.open_stream(source).read_all()
            return shard.cast(shard.schema.remove_metadata())

    shards = [_read_shard(shard_path) for shard_path in shard_paths]
    if len(shards) == 1:
        merged = shards[0]
    else:
        merged = pa.concat_tables(shards, promote_options="default")
    return Dataset(merged)
def load_elf_conditional_dataset(path: str, max_records: int = 0):
    """Load an ELF-style conditional dataset previously saved to disk.

    The dataset is loaded with ``datasets.load_from_disk`` and reduced to a
    single split via ``_datasetdict_to_single_split``. If loading fails with
    the "Feature type 'List' not found" ``ValueError`` from an older
    ``datasets`` build, the Arrow shards are read directly instead. Any other
    ``ValueError`` propagates unchanged.

    Args:
        path: directory passed to ``load_from_disk``.
        max_records: when truthy and positive, keep only the first
            ``min(max_records, len(ds))`` rows.

    Raises:
        ValueError: if ``condition_input_ids`` or ``input_ids`` is missing
            from the dataset columns.
    """
    from datasets import load_from_disk

    try:
        loaded = load_from_disk(path)
        ds = _datasetdict_to_single_split(loaded)
    except ValueError as exc:
        if "Feature type 'List' not found" in str(exc):
            ds = _load_arrow_dir_without_hf_metadata(path)
        else:
            raise
    if max_records and max_records > 0:
        keep = min(int(max_records), len(ds))
        ds = ds.select(range(keep))
    missing = sorted({"condition_input_ids", "input_ids"} - set(ds.column_names))
    if missing:
        raise ValueError(f"ELF conditional dataset {path} missing columns: {missing}")
    return ds
class ELFConditionalCollator:
    """Pack (condition, target) token-id pairs into fixed-width training rows.

    Each example provides ``condition_input_ids`` and ``input_ids``. The
    condition is truncated to ``max_input_len`` tokens and placed first. The
    target fills the remaining room up to ``max_len``, and any leftover width
    is filled with ``pad_id``.

    The batch returned by ``__call__`` contains ``ids`` (long) and three bool
    masks, all shaped ``(batch, max_len)``:
      * ``attn_mask``: True on real (condition + target) tokens.
      * ``loss_mask``: True on target tokens. With ``loss_on_pad`` it is True
        on every non-condition position, including padding.
      * ``cond_seq_mask``: True on condition tokens.
    """

    def __init__(self, pad_id: int, max_len: int, max_input_len: int, *, loss_on_pad: bool) -> None:
        self.pad_id = int(pad_id)
        self.max_len = int(max_len)
        self.max_input_len = int(max_input_len)
        self.loss_on_pad = bool(loss_on_pad)

    def _encode_row(self, example: dict) -> tuple[list[int], list[bool], list[bool], list[bool]]:
        """Build the (ids, attn, loss, cond) rows for a single example."""
        width = self.max_len
        condition = [int(tok) for tok in example["condition_input_ids"][: self.max_input_len]]
        room_for_target = max(0, width - len(condition))
        target = [int(tok) for tok in example["input_ids"][:room_for_target]]
        tokens = (condition + target)[:width]
        # The condition itself may exceed max_len when max_input_len > max_len.
        n_cond = min(len(condition), width)
        n_target = max(0, len(tokens) - n_cond)
        n_pad = width - len(tokens)
        if n_pad > 0:
            tokens = tokens + [self.pad_id] * n_pad
        attn = [True] * (n_cond + n_target) + [False] * n_pad
        cond = [True] * n_cond + [False] * (width - n_cond)
        if self.loss_on_pad:
            loss_tail = [True] * (width - n_cond)
        else:
            loss_tail = [True] * n_target + [False] * n_pad
        loss = [False] * n_cond + loss_tail
        return tokens, attn, loss, cond

    def __call__(self, examples: list[dict]) -> dict[str, torch.Tensor]:
        rows = [self._encode_row(example) for example in examples]
        ids, attn, loss, cond = [[row[col] for row in rows] for col in range(4)]
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "attn_mask": torch.tensor(attn, dtype=torch.bool),
            "loss_mask": torch.tensor(loss, dtype=torch.bool),
            "cond_seq_mask": torch.tensor(cond, dtype=torch.bool),
        }
class ExactRepeatDistributedSampler(Sampler[int]):
    """Deterministic finite schedule with exact per-item repeat counts.

    Every index in ``range(dataset_size)`` appears exactly ``repeats_per_item``
    times in the global schedule. The schedule is permuted once with a
    ``torch.Generator`` seeded by ``seed``, so it is identical on every rank
    and on every iteration. Rank ``r`` then takes positions
    ``r, r + world_size, r + 2 * world_size, ...`` of the permuted schedule.

    Args:
        dataset_size: number of distinct items (must be positive).
        repeats_per_item: times each item appears globally (must be positive).
        seed: seed for the single global permutation.
        rank: this process's rank, in ``[0, world_size)``.
        world_size: number of ranks sharing the schedule (must be positive).

    Raises:
        ValueError: on non-positive sizes/repeats/world_size, an out-of-range
            ``rank``, or a total schedule length not divisible by
            ``world_size`` (which guarantees every rank yields ``len(self)``
            indices).
    """

    def __init__(
        self,
        dataset_size: int,
        repeats_per_item: int,
        *,
        seed: int,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        self.dataset_size = int(dataset_size)
        self.repeats_per_item = int(repeats_per_item)
        self.seed = int(seed)
        self.rank = int(rank)
        self.world_size = int(world_size)
        if self.dataset_size <= 0:
            raise ValueError("dataset_size must be positive")
        if self.repeats_per_item <= 0:
            raise ValueError("repeats_per_item must be positive")
        # Validate sharding parameters before using world_size as a divisor:
        # world_size == 0 would otherwise raise ZeroDivisionError, and an
        # out-of-range rank would silently produce a wrong/short shard.
        if self.world_size <= 0:
            raise ValueError(f"world_size must be positive, got {self.world_size}")
        if not 0 <= self.rank < self.world_size:
            raise ValueError(f"rank must be in [0, {self.world_size}), got {self.rank}")
        total = self.dataset_size * self.repeats_per_item
        if total % self.world_size != 0:
            raise ValueError(
                f"exact repeat schedule total={total} is not divisible by world_size={self.world_size}"
            )
        self.num_samples = total // self.world_size

    def __iter__(self):
        indices = torch.arange(self.dataset_size, dtype=torch.long).repeat_interleave(
            self.repeats_per_item
        )
        # Fresh generator per iteration, so every epoch replays the same order.
        g = torch.Generator()
        g.manual_seed(self.seed)
        indices = indices[torch.randperm(indices.numel(), generator=g)]
        return iter(indices[self.rank :: self.world_size].tolist())

    def __len__(self) -> int:
        return self.num_samples
def seed_all(seed: int) -> None:
    """Seed Python's ``random`` module, torch's CPU RNG, and all CUDA RNGs with ``seed``."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def sample_time(
    batch: int,
    min_t: float,
    max_t: float,
    device: torch.device,
    mode: str = "uniform",
    power: float = 1.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Draw ``batch`` flow times, mapped affinely onto ``[min_t, max_t]``.

    A base variate ``z`` is derived from ``u ~ U[0, 1)`` according to ``mode``,
    and the result is ``min_t + (max_t - min_t) * z``:

    * ``uniform``: ``z = u``.
    * ``power_low``: ``z = u ** power``; ``power > 1`` oversamples low t.
    * ``power_high``: ``z = 1 - (1 - u) ** power``; ``power > 1`` oversamples high t.
    * ``logit_uniform``: ``logit(z)`` is uniform on ``[logit(eps), logit(1 - eps)]``.

    Raises:
        ValueError: for an unknown ``mode``; for ``power <= 0`` in the power
            modes (``power == 0`` collapses every sample onto one endpoint and
            ``power < 0`` pushes samples outside ``[min_t, max_t]``); or for
            ``eps`` outside ``(0, 0.5)`` in ``logit_uniform`` mode (``eps == 0``
            makes the logit range infinite and produces NaN times).
    """
    u = torch.rand(batch, device=device)
    if mode == "uniform":
        z = u
    elif mode in ("power_low", "power_high"):
        if not float(power) > 0.0:
            raise ValueError(f"t_sampling_power must be > 0 for mode {mode!r}, got {power}")
        if mode == "power_low":
            z = u.pow(float(power))
        else:
            z = 1.0 - (1.0 - u).pow(float(power))
    elif mode == "logit_uniform":
        if not 0.0 < float(eps) < 0.5:
            raise ValueError(f"t_sampling_eps must be in (0, 0.5) for logit_uniform, got {eps}")
        lo = torch.logit(torch.tensor(float(eps), device=device))
        hi = torch.logit(torch.tensor(1.0 - float(eps), device=device))
        z = torch.sigmoid(lo + (hi - lo) * u)
    else:
        raise ValueError(f"Unknown t_sampling_mode: {mode}")
    return min_t + (max_t - min_t) * z
def sample_loss_t_weight(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor:
    """Return a per-sample loss weight derived from the model time ``model_t``.

    ``args.loss_t_weight_mode`` selects the curve, evaluated on ``model_t``
    clamped to ``[0, 1]``:

    * ``none``: all ones (no clamping and no floor are applied).
    * ``linear_t``: ``t``.
    * ``quadratic_t``: ``t ** 2``.
    * ``one_minus_t_floor``: ``1 - t``.
    * ``drop_low_t``: 0 where ``t < args.loss_t_drop_below``, else 1.

    For every mode except ``none``, a positive ``args.loss_t_min_weight`` is
    applied as a lower bound.

    Raises:
        ValueError: for an unrecognized ``loss_t_weight_mode``.
    """
    mode = args.loss_t_weight_mode
    if mode == "none":
        return torch.ones_like(model_t)
    t = model_t.clamp(0.0, 1.0)
    curves = {
        "linear_t": lambda: t,
        "quadratic_t": lambda: t.square(),
        "one_minus_t_floor": lambda: 1.0 - t,
        "drop_low_t": lambda: torch.where(
            t < float(args.loss_t_drop_below), torch.zeros_like(t), torch.ones_like(t)
        ),
    }
    curve = curves.get(mode)
    if curve is None:
        raise ValueError(f"Unknown loss_t_weight_mode: {args.loss_t_weight_mode}")
    weight = curve()
    floor = args.loss_t_min_weight
    return weight.clamp_min(float(floor)) if floor > 0 else weight
def parse_t_bin_edges(raw: str) -> tuple[float, ...]:
    """Parse ``--log_t_bin_edges`` into a validated tuple of bin edges.

    ``raw`` is a comma-separated list of floats. Blank entries, such as those
    from a trailing comma, are ignored.

    Raises:
        ValueError: if fewer than two edges are given, any edge is NaN, the
            edges are not strictly increasing, or they do not cover ``[0, 1]``.
    """
    vals = tuple(float(x) for x in raw.split(",") if x.strip())
    if len(vals) < 2:
        raise ValueError("--log_t_bin_edges must contain at least two comma-separated values")
    # NaN compares False against everything, so it would slip through both the
    # ordering and coverage checks below; reject it explicitly.
    if any(math.isnan(v) for v in vals):
        raise ValueError(f"--log_t_bin_edges must not contain NaN, got {raw!r}")
    if any(vals[i] >= vals[i + 1] for i in range(len(vals) - 1)):
        raise ValueError(f"--log_t_bin_edges must be strictly increasing, got {raw!r}")
    if vals[0] > 0.0 or vals[-1] < 1.0:
        raise ValueError(f"--log_t_bin_edges must cover [0, 1], got {raw!r}")
    return vals
def average_stats_window(running: list[dict[str, float]]) -> dict[str, float]:
    """Average per-step stat dicts collected over one logging window.

    A stat ``k`` may come with a companion ``k__weight`` entry in the same dict.
    In that case the entry contributes to a weighted mean, but only when its
    weight is finite and positive. Entries without a companion weight go into
    a plain arithmetic mean. Non-finite stat values are skipped.

    For each key, the weighted mean is used if any positive weight was seen;
    otherwise the plain mean is used if any plain values exist; otherwise the
    key is omitted. ``__weight`` keys are never reported. Output keys follow
    the order in which each stat first appears in ``running``.
    """
    ordered_keys = list(
        dict.fromkeys(
            key for stats in running for key in stats if not key.endswith("__weight")
        )
    )
    result: dict[str, float] = {}
    for key in ordered_keys:
        companion = key + "__weight"
        numerator = 0.0
        denominator = 0.0
        unweighted: list[float] = []
        for stats in running:
            if key not in stats:
                continue
            value = float(stats[key])
            if not math.isfinite(value):
                continue
            if companion not in stats:
                unweighted.append(value)
                continue
            w = float(stats[companion])
            if math.isfinite(w) and w > 0.0:
                numerator += value * w
                denominator += w
        if denominator > 0.0:
            result[key] = numerator / denominator
        elif unweighted:
            result[key] = sum(unweighted) / len(unweighted)
    return result
def weighted_masked_ce(logits: torch.Tensor, ids: torch.Tensor, mask: torch.Tensor, sample_weight: torch.Tensor) -> torch.Tensor:
    """Hard-label cross-entropy averaged over masked tokens with per-sample weights.

    ``logits`` is flattened over its first two dims and ``ids`` over all dims
    before ``cross_entropy``, and the per-token loss is reshaped back to
    ``ids``'s shape. Each token's weight is ``mask * sample_weight[row]``
    (``sample_weight`` is viewed as ``(-1, 1)``). The result is the weighted
    sum divided by the total weight, with the denominator floored at 1.0.
    """
    flat_loss = F.cross_entropy(logits.flatten(0, 1), ids.flatten(), reduction="none")
    token_loss = flat_loss.view_as(ids)
    token_weight = mask.float() * sample_weight.float().view(-1, 1)
    denom = token_weight.sum().clamp_min(1.0)
    return (token_loss * token_weight).sum() / denom
def weighted_masked_soft_ce(
    logits: torch.Tensor,
    target_probs: torch.Tensor,
    mask: torch.Tensor,
    sample_weight: torch.Tensor,
) -> torch.Tensor:
    """Soft-target cross-entropy averaged over masked tokens with per-sample weights.

    The per-token loss is ``-sum(target_probs * log_softmax(logits))`` over the
    last dim. ``target_probs`` is cast to the log-prob dtype before the product,
    and the summed result is cast to float32. Weighting matches
    ``weighted_masked_ce``: ``mask * sample_weight[row]``, normalized by the
    total weight floored at 1.0.
    """
    log_p = F.log_softmax(logits, dim=-1)
    token_sum = (target_probs.to(dtype=log_p.dtype) * log_p).sum(dim=-1)
    token_loss = -token_sum.float()
    token_weight = mask.float() * sample_weight.float().view(-1, 1)
    denom = token_weight.sum().clamp_min(1.0)
    return (token_loss * token_weight).sum() / denom
def make_prior_center_state(
    args: argparse.Namespace,
    batch: int,
    length: int,
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Build the prior-centering reference state of shape ``(batch, length, vocab_size)``.

    Only the ``uniform`` prior is supported: every position holds the uniform
    distribution ``1 / vocab_size`` in float32. The tensor is returned as
    probabilities when ``args.state_format == "prob"``, and otherwise as
    ``log(p.clamp_min(args.eps))``.

    Raises:
        ValueError: if ``args.prior_center_state`` is not ``uniform``.
    """
    kind = args.prior_center_state
    if kind != "uniform":
        raise ValueError(f"Unknown prior_center_state: {args.prior_center_state}")
    shape = (batch, length, vocab_size)
    uniform = torch.full(shape, 1.0 / vocab_size, dtype=torch.float32, device=device)
    return uniform if args.state_format == "prob" else torch.log(uniform.clamp_min(args.eps))
def state_to_probs(args: argparse.Namespace, state: torch.Tensor) -> torch.Tensor:
    """Convert a model state to float32 probabilities over the last dim.

    With ``args.state_format == "prob"`` the state is used as-is; otherwise it
    is treated as log-probabilities and exponentiated. Entries are then floored
    at ``args.eps`` and renormalized, with the normalizer also floored at
    ``args.eps``.
    """
    raw = state.float()
    if args.state_format != "prob":
        raw = raw.exp()
    raw = raw.clamp_min(args.eps)
    total = raw.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    return raw / total
def probs_to_state(args: argparse.Namespace, probs: torch.Tensor) -> torch.Tensor:
    """Convert probabilities to the model state layout selected by ``args.state_format``.

    The probabilities are cast to float32, floored at ``args.eps`` and
    renormalized over the last dim (normalizer also floored at ``args.eps``).
    They are returned directly for ``prob`` format, and as
    ``log(p.clamp_min(args.eps))`` otherwise.
    """
    floored = probs.float().clamp_min(args.eps)
    normalized = floored / floored.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    if args.state_format != "prob":
        return torch.log(normalized.clamp_min(args.eps))
    return normalized
def zero_param_anchor(model: torch.nn.Module, like: torch.Tensor) -> torch.Tensor:
    """Return a 0-dim zero tensor that is connected to every trainable parameter.

    Adding the result to a loss leaves the loss value unchanged. It does place
    each ``requires_grad`` parameter of ``model`` in the autograd graph, so
    each one receives a (zero) gradient on backward. Presumably this keeps
    parameters that a given step did not use from being reported as unused
    under DDP — confirm against the caller.

    Args:
        model: module whose trainable parameters are anchored.
        like: tensor whose device and dtype seed the zero scalar
            (``like.new_zeros(())``).

    Returns:
        A 0-dim zero tensor. Each parameter term is cast with ``.float()``
        before accumulation, so whenever a trainable parameter exists the
        result dtype follows torch type promotion between ``like.dtype`` and
        float32. With no trainable parameters it is ``like.new_zeros(())``.
    """
    anchor = like.new_zeros(())
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # Slice rather than index the first element. Zero-element parameters
        # then contribute an empty sum (0) instead of raising IndexError. For
        # non-empty parameters the one-element sum equals p.reshape(-1)[0], so
        # both the value and the gradient are unchanged.
        anchor = anchor + p.reshape(-1)[:1].sum().float() * 0.0
    return anchor
@torch.no_grad()
def maybe_make_rollout_train_state(
    args: argparse.Namespace,
    model: torch.nn.Module,
    base_state: torch.Tensor,
    model_t: torch.Tensor,
    attn_mask: torch.Tensor,
    corrupt_mask: torch.Tensor,
) -> tuple[torch.Tensor, float, torch.Tensor | None]:
    """Optionally replace the bridge input state with a detached self-rollout state.

    Starting from ``base_state``, this runs ``args.rollout_train_steps``
    no-grad model updates. It then decides, via a per-batch or per-sample coin
    with probability ``args.rollout_train_prob``, whether to swap the rolled-out
    state in.

    Each rollout step computes ``endpoint = softmax(model(state, t, attn_mask) / rollout_train_temp)``
    and moves ``probs`` toward it by ``gamma = dt / (1 - t)``, where
    ``dt = 1 / rollout_train_infer_steps``. ``gamma`` is clamped to
    ``rollout_train_max_gamma`` when that value is > 0. The result is
    re-floored/renormalized and ``t`` is advanced by ``dt`` (capped at 1).

    Args:
        args: run config. Reads ``rollout_train_*``, ``state_format`` and ``eps``.
        model: called as ``model(state, t, attn_mask)``; its output is
            softmaxed over the last dim (assumed vocabulary logits — TODO confirm).
        base_state: bridge state in ``args.state_format``. Presumably
            ``(batch, seq, vocab)``, given the ``view(-1, 1, 1)`` broadcasts below.
        model_t: per-sample time; ``size(0)`` must equal ``base_state.size(0)``.
        attn_mask: forwarded unchanged to ``model``.
        corrupt_mask: bool mask combined with ``apply_sample.view(-1, 1)``. It is
            only used when ``args.rollout_train_corrupt_only`` is set.

    Returns:
        ``(state, applied, apply_sample)``:
        * ``state``: ``base_state`` itself on every early-exit path; otherwise a
          detached float32 state built by ``probs_to_state``.
        * ``applied``: fraction of samples selected (samplewise), or 1.0/0.0
          (batchwise). It is forced to 0.0 when the early exits trigger.
        * ``apply_sample``: per-sample bool selection, or ``None`` when rollout
          is disabled or a batchwise coin miss skips the forward entirely.

    Note: the order of the ``torch.rand`` draws and of the early returns
    determines RNG consumption. The samplewise/batchwise ``.cpu()`` /
    ``.item()`` calls force a device sync.
    """
    prob = float(args.rollout_train_prob)
    steps = int(args.rollout_train_steps)
    if prob <= 0.0 or steps <= 0:
        return base_state, 0.0, None
    samplewise = bool(getattr(args, "rollout_train_samplewise", False))
    if samplewise:
        # Independent coin per sample; the rollout forward always runs.
        apply_sample = torch.rand(base_state.size(0), device=model_t.device) < prob
        applied = float(apply_sample.float().mean().detach().cpu())
    else:
        # One coin for the whole micro-batch (no draw when prob >= 1).
        apply_batch = prob >= 1.0 or float(torch.rand((), device=model_t.device).item()) < prob
        if not apply_batch and not bool(getattr(args, "rollout_train_compute_always", False)):
            return base_state, 0.0, None
        apply_sample = torch.full((base_state.size(0),), bool(apply_batch), device=model_t.device)
        applied = 1.0 if apply_batch else 0.0
    probs = state_to_probs(args, base_state)
    original_probs = probs
    rollout_t = model_t.float().clamp(0.0, 1.0)
    dt = 1.0 / max(1, int(args.rollout_train_infer_steps))
    with torch.no_grad():
        for _ in range(steps):
            logits = model(probs_to_state(args, probs), rollout_t, attn_mask)
            endpoint = torch.nn.functional.softmax(logits / float(args.rollout_train_temp), dim=-1)
            del logits  # free the logits tensor before the update math
            # One-step flowmap: fraction of remaining distance covered in dt.
            gamma = dt / (1.0 - rollout_t).clamp_min(args.eps)
            if float(args.rollout_train_max_gamma) > 0:
                gamma = gamma.clamp_max(float(args.rollout_train_max_gamma))
            probs = probs + gamma.view(-1, 1, 1) * (endpoint - probs)
            probs = probs.clamp_min(args.eps)
            probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
            rollout_t = (rollout_t + dt).clamp_max(1.0)
    # Debug path: the rollout forward has run, but the original state is returned.
    if bool(getattr(args, "rollout_train_debug_return_base", False)):
        return base_state, float(applied), apply_sample.detach()
    if applied <= 0.0:
        return base_state, 0.0, apply_sample.detach()
    if args.rollout_train_corrupt_only:
        # Only corrupted positions of selected samples take the rollout state;
        # clean anchor positions keep their original probabilities.
        apply_pos = corrupt_mask
        apply_pos = apply_pos & apply_sample.view(-1, 1)
        probs = torch.where(apply_pos.unsqueeze(-1), probs, original_probs)
    else:
        probs = torch.where(apply_sample.view(-1, 1, 1), probs, original_probs)
    return probs_to_state(args, probs).detach(), float(applied), apply_sample.detach()
@torch.no_grad()
def zeropower_via_newtonschulz5(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize ``g`` with Muon's quintic Newton-Schulz iteration.

    ``g`` is flattened to ``(g.shape[0], -1)`` in float32. It is transposed if
    it has more rows than columns and scaled by ``1 / (||g||_F + 1e-7)``. Then
    ``max(1, steps)`` iterations of ``X <- a X + (b A + c A @ A) @ X``
    (with ``A = X X^T``) are applied. The result is transposed back and
    reshaped/cast to ``g``'s original shape and dtype.
    """
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315
    mat = g.float().reshape(g.shape[0], -1)
    tall = mat.size(0) > mat.size(1)
    if tall:
        mat = mat.t()
    mat = mat / (mat.norm() + 1e-7)
    for _ in range(max(1, int(steps))):
        gram = mat @ mat.t()
        poly = coef_b * gram + coef_c * (gram @ gram)
        mat = coef_a * mat + poly @ mat
    if tall:
        mat = mat.t()
    return mat.reshape(g.shape).to(dtype=g.dtype)
class Muon(torch.optim.Optimizer):
    """Minimal Muon optimizer with an Adam fallback for parameters of rank < 2.

    Trainable parameters with ``dim() >= 2`` form a group flagged ``use_muon=True``.
    Their momentum-averaged gradient is orthogonalized by ``zeropower_via_newtonschulz5``
    and applied with step ``lr * update_scale``. All remaining trainable parameters form a
    ``use_muon=False`` group updated with bias-corrected Adam.

    Both groups carry the same hyperparameters. Both apply decoupled weight decay
    (``p *= 1 - lr * weight_decay``) before the update when ``weight_decay`` is non-zero.
    Sparse gradients raise ``RuntimeError``.
    """

    def __init__(
        self,
        params,
        *,
        lr: float,
        momentum: float = 0.95,
        ns_steps: int = 5,
        adam_betas: tuple[float, float] = (0.9, 0.95),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        update_scale: float = 1.0,
    ) -> None:
        trainable = [p for p in params if p.requires_grad]
        defaults = {
            "lr": lr,
            "momentum": momentum,
            "ns_steps": ns_steps,
            "adam_betas": adam_betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "update_scale": update_scale,
        }
        param_groups = []
        # Muon group (matrices) first, then the Adam group; empty groups are omitted.
        for use_muon in (True, False):
            members = [p for p in trainable if (p.dim() >= 2) == use_muon]
            if members:
                param_groups.append({"params": members, "use_muon": use_muon, **defaults})
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one update over every group; return ``closure()``'s value if a closure is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            if group.get("use_muon", False):
                self._muon_group_step(group)
            else:
                self._adam_group_step(group)
        return loss

    def _muon_group_step(self, group) -> None:
        """Momentum + Newton-Schulz orthogonalized update for one ``use_muon`` group."""
        lr = group["lr"]
        wd = group["weight_decay"]
        beta = group["momentum"]
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Muon does not support sparse gradients")
            slot = self.state[p]
            if not slot:
                slot["momentum_buffer"] = torch.zeros_like(p)
            buf = slot["momentum_buffer"]
            # EMA of gradients: buf = beta * buf + (1 - beta) * grad.
            buf.mul_(beta).add_(grad, alpha=1.0 - beta)
            ortho = zeropower_via_newtonschulz5(buf, group["ns_steps"])
            if wd:
                p.mul_(1.0 - lr * wd)
            p.add_(ortho, alpha=-lr * group["update_scale"])

    def _adam_group_step(self, group) -> None:
        """Bias-corrected Adam update for one non-Muon group."""
        lr = group["lr"]
        wd = group["weight_decay"]
        beta1, beta2 = group["adam_betas"]
        eps = group["eps"]
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Adam fallback does not support sparse gradients")
            slot = self.state[p]
            if not slot:
                slot["step"] = 0
                slot["exp_avg"] = torch.zeros_like(p)
                slot["exp_avg_sq"] = torch.zeros_like(p)
            slot["step"] += 1
            t = slot["step"]
            first_moment = slot["exp_avg"]
            second_moment = slot["exp_avg_sq"]
            if wd:
                p.mul_(1.0 - lr * wd)
            first_moment.mul_(beta1).add_(grad, alpha=1.0 - beta1)
            second_moment.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
            # Both bias corrections are folded into the scalar step; eps is added to the raw sqrt.
            step_size = lr * ((1.0 - beta2**t) ** 0.5) / (1.0 - beta1**t)
            denom = second_moment.sqrt().add_(eps)
            p.addcdiv_(first_moment, denom, value=-step_size)
def configure_adamw_optimizer(
    model: torch.nn.Module,
    args: argparse.Namespace,
    device: torch.device,
) -> torch.optim.Optimizer:
    """Build the optimizer for ``model``'s trainable parameters from CLI ``args``.

    With ``args.optimizer == "muon"`` a :class:`Muon` instance is returned. In that case
    ``output_weight_decay``, ``adamw_param_groups`` and fused kernels are not consulted.

    Otherwise an AdamW optimizer is built:

    * When ``args.output_weight_decay`` (treated as -1.0 if absent) is >= 0, the output-head
      weights get their own group with that decay. These are the parameters named exactly
      ``output_layer.linear.weight`` or ``out_proj.weight``.
    * ``adamw_param_groups == "all_decay"``: every other parameter uses ``args.weight_decay``.
    * Any other value: nanoGPT-style split. Parameters with ``dim() >= 2`` use
      ``args.weight_decay``; lower-rank ones (biases, norm scales) use 0.
    * ``fused=True`` is passed when this torch build's AdamW accepts it and ``device`` is CUDA.
    """
    named_params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    params = [p for _, p in named_params]
    if args.optimizer == "muon":
        return Muon(
            params,
            lr=args.lr,
            momentum=args.muon_momentum,
            ns_steps=args.muon_ns_steps,
            adam_betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_eps,
            weight_decay=args.weight_decay,
            update_scale=args.muon_update_scale,
        )
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    extra_args = {"fused": True} if fused_available and device.type == "cuda" else {}
    adamw_kwargs = {
        "lr": args.lr,
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_eps,
        **extra_args,
    }
    output_decay = float(getattr(args, "output_weight_decay", -1.0))
    # Match the head by its full parameter name. A suffix test would also catch every
    # nested "<block>.out_proj.weight" (e.g. the projection inside nn.MultiheadAttention)
    # and silently give those matrices the head's weight decay.
    output_weight_names = frozenset(("output_layer.linear.weight", "out_proj.weight"))
    output_params = (
        [p for name, p in named_params if name in output_weight_names]
        if output_decay >= 0.0
        else []
    )
    output_param_ids = {id(p) for p in output_params}
    if args.adamw_param_groups == "all_decay":
        if not output_params:
            return torch.optim.AdamW(params, weight_decay=args.weight_decay, **adamw_kwargs)
        other_params = [p for p in params if id(p) not in output_param_ids]
        optim_groups = [
            {"params": other_params, "weight_decay": args.weight_decay},
            {"params": output_params, "weight_decay": output_decay},
        ]
        return torch.optim.AdamW(optim_groups, **adamw_kwargs)
    # nanoGPT convention: decay matrix/embedding weights, not biases/norm/1D params.
    decay_params = [p for p in params if p.dim() >= 2 and id(p) not in output_param_ids]
    nodecay_params = [p for p in params if p.dim() < 2 and id(p) not in output_param_ids]
    optim_groups = []
    if decay_params:
        optim_groups.append({"params": decay_params, "weight_decay": args.weight_decay})
    if output_params:
        optim_groups.append({"params": output_params, "weight_decay": output_decay})
    if nodecay_params:
        optim_groups.append({"params": nodecay_params, "weight_decay": 0.0})
    return torch.optim.AdamW(optim_groups, **adamw_kwargs)
@torch.no_grad()
def init_ema_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Snapshot every floating-point entry of ``model.state_dict()`` as the initial EMA.

    Non-float entries (e.g. integer buffers) are skipped. Each kept tensor is a detached
    clone, so later in-place EMA updates never touch the live model.
    """
    snapshot = {}
    for name, tensor in model.state_dict().items():
        if not torch.is_floating_point(tensor):
            continue
        snapshot[name] = tensor.detach().clone()
    return snapshot
@torch.no_grad()
def update_ema_state(ema_state: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
    """Blend the model's current state into ``ema_state`` in place.

    For every state-dict key already tracked in ``ema_state``:
    ``ema = decay * ema + (1 - decay) * current``. Untracked keys are ignored.
    """
    blend = 1.0 - decay
    for name, live in model.state_dict().items():
        shadow = ema_state.get(name)
        if shadow is None:
            continue
        shadow.mul_(decay).add_(live.detach(), alpha=blend)
def resolve_bridge_force_t(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor | None:
    """Choose the corruption time to force on the bridge.

    * ``args.dual_t`` off, or ``corrupt_t_mode == "same"``: reuse ``model_t`` itself.
    * ``corrupt_t_mode == "constant"``: a tensor like ``model_t`` filled with ``corrupt_t_value``.
    * any other mode: ``None``, leaving the bridge free to sample its own time.
    """
    if args.dual_t and args.corrupt_t_mode != "same":
        if args.corrupt_t_mode == "constant":
            return torch.full_like(model_t, float(args.corrupt_t_value))
        return None
    return model_t
def load_corpus_unigram_probs(
    path: str,
    vocab_size: int,
    alpha: float,
    device: torch.device,
) -> torch.Tensor | None:
    """Load a corpus unigram distribution over ``vocab_size`` tokens.

    ``path`` may be either of:

    * a ``.npy`` array, read with numpy;
    * any file ``torch.load`` can read: a tensor or sequence, or a dict whose first present
      key among ``probs``, ``prob``, ``counts``, ``count``, ``unigram_counts`` holds the data.

    The values are processed as follows:

    * flattened to float32;
    * zero-padded or truncated to ``vocab_size`` entries;
    * clamped at 0;
    * raised to the power ``alpha`` (skipped when ``alpha == 1``);
    * normalized to sum to 1.

    Returns:
        A float32 tensor of shape ``(vocab_size,)`` on ``device``, or ``None`` when ``path`` is empty.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: a dict payload has none of the recognized keys, or the processed values
            have no positive finite total mass.
    """
    if not path:
        return None
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"categorical wrong corpus unigram file not found: {p}")
    if p.suffix == ".npy":
        import numpy as np

        data = torch.from_numpy(np.load(p))
    else:
        # NOTE(review): depending on the installed torch version, torch.load may unpickle
        # arbitrary objects; only point this at trusted files.
        loaded = torch.load(p, map_location="cpu")
        if isinstance(loaded, dict):
            known_keys = ("probs", "prob", "counts", "count", "unigram_counts")
            key = next((k for k in known_keys if k in loaded), None)
            if key is None:
                # Otherwise the dict itself would reach torch.as_tensor and fail with an
                # unhelpful dtype-inference error.
                raise ValueError(
                    f"corpus unigram dict in {p} has none of the keys {known_keys}; "
                    f"found {sorted(str(k) for k in loaded)}"
                )
            loaded = loaded[key]
        data = torch.as_tensor(loaded)
    probs = data.flatten().to(dtype=torch.float32, device="cpu")
    if probs.numel() < vocab_size:
        probs = F.pad(probs, (0, vocab_size - probs.numel()))
    elif probs.numel() > vocab_size:
        probs = probs[:vocab_size]
    probs = probs.clamp_min(0.0)
    if float(alpha) != 1.0:
        probs = probs.pow(float(alpha))
    total = probs.sum()
    if not bool(torch.isfinite(total).item()) or float(total.item()) <= 0.0:
        raise ValueError(f"corpus unigram file has no positive finite mass: {p}")
    return (probs / total).to(device=device)
def make_bridge(
    args: argparse.Namespace,
    ids: torch.Tensor,
    attn_mask: torch.Tensor,
    vocab_size: int,
    force_t: torch.Tensor | None = None,
    force_corrupt_mask: torch.Tensor | None = None,
    categorical_wrong_corpus_unigram_probs: torch.Tensor | None = None,
):
    """Build one corrupted training batch with the bridge named by ``args.bridge``.

    Dispatch:
      * ``"prob"``     -> ``make_prob_bridge_batch``
      * ``"gaussian"`` -> ``make_gaussian_bridge_batch`` (optional import; raises
        RuntimeError when it could not be imported)
      * anything else  -> ``make_dirichlet_bridge_batch``

    The corruption-time range is ``args.corrupt_min_t`` / ``args.corrupt_max_t``, each
    falling back to ``args.min_t`` / ``args.max_t`` when it is None.

    ``force_t`` is forwarded to every bridge. ``force_corrupt_mask`` goes to the prob and
    dirichlet bridges only. ``categorical_wrong_corpus_unigram_probs`` is used by the
    dirichlet bridge only. Returns whatever the selected bridge function returns.
    """
    t_lo = args.corrupt_min_t
    if t_lo is None:
        t_lo = args.min_t
    t_hi = args.corrupt_max_t
    if t_hi is None:
        t_hi = args.max_t
    bridge_kind = args.bridge

    if bridge_kind == "gaussian":
        if make_gaussian_bridge_batch is None:
            raise RuntimeError("bridge=gaussian requested, but make_gaussian_bridge_batch is unavailable")
        return make_gaussian_bridge_batch(
            ids=ids,
            attn_mask=attn_mask,
            vocab_size=vocab_size,
            target_prob=args.target_prob,
            min_t=t_lo,
            max_t=t_hi,
            eps=args.eps,
            force_t=force_t,
        )

    # Keyword arguments the prob and dirichlet bridges have in common.
    shared_kwargs = dict(
        ids=ids,
        attn_mask=attn_mask,
        vocab_size=vocab_size,
        target_prob=args.target_prob,
        min_t=t_lo,
        max_t=t_hi,
        min_mask_ratio=args.min_mask_ratio,
        max_mask_ratio=args.max_mask_ratio,
        wrong_token_replace_prob=args.wrong_token_replace_prob,
        wrong_token_schedule=args.wrong_token_schedule,
        wrong_token_exp_k=args.wrong_token_exp_k,
        eps=args.eps,
        state_format=args.state_format,
        force_t=force_t,
        force_corrupt_mask=force_corrupt_mask,
        mask_ratio_floor_schedule=args.mask_ratio_floor_schedule,
        mask_mixture_original_prob=args.mask_mixture_original_prob,
        mask_mixture_lowk_prob=args.mask_mixture_lowk_prob,
        mask_mixture_lowcorrupt_prob=args.mask_mixture_lowcorrupt_prob,
        mask_mixture_block_prob=args.mask_mixture_block_prob,
        mask_mixture_all_prob=args.mask_mixture_all_prob,
        mask_mixture_lowk_clean_tokens=args.mask_mixture_lowk_clean_tokens,
        mask_mixture_lowcorrupt_tokens=args.mask_mixture_lowcorrupt_tokens,
        mask_mixture_block_tokens=args.mask_mixture_block_tokens,
        clean_state_mode=args.clean_state_mode,
    )

    if bridge_kind == "prob":
        return make_prob_bridge_batch(
            **shared_kwargs,
            noise_init=args.bridge_noise_init,
            noise_sigma=args.noise_sigma,
        )

    return make_dirichlet_bridge_batch(
        **shared_kwargs,
        dirichlet_concentration_min=args.dirichlet_concentration_min,
        dirichlet_concentration_max=args.dirichlet_concentration_max,
        dirichlet_endpoint_mode=args.dirichlet_endpoint_mode,
        dirichlet_semantic_t_mode=args.dirichlet_semantic_t_mode,
        dirichlet_semantic_t_value=args.dirichlet_semantic_t_value,
        dirichlet_semantic_t_curve=args.dirichlet_semantic_t_curve,
        dirichlet_semantic_t_power=args.dirichlet_semantic_t_power,
        endpoint_sequence_random_prob_alpha=args.endpoint_sequence_random_prob_alpha,
        categorical_wrong_from_full_vocab=args.categorical_wrong_from_full_vocab,
        categorical_wrong_from_batch_valid_tokens=args.categorical_wrong_from_batch_valid_tokens,
        categorical_wrong_basin_token_ids=args.categorical_wrong_basin_token_ids,
        categorical_wrong_basin_prob=args.categorical_wrong_basin_prob,
        categorical_wrong_unigram_prob=args.categorical_wrong_unigram_prob,
        categorical_wrong_uniform_prob=args.categorical_wrong_uniform_prob,
        categorical_wrong_basin_shared_prob=args.categorical_wrong_basin_shared_prob,
        categorical_wrong_unigram_shared_prob=args.categorical_wrong_unigram_shared_prob,
        categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs,
        simplex_bridge_sampler=args.simplex_bridge_sampler,
        logistic_normal_sigma_min=args.logistic_normal_sigma_min,
        logistic_normal_sigma_max=args.logistic_normal_sigma_max,
        logistic_normal_tau_min=args.logistic_normal_tau_min,
        logistic_normal_tau_max=args.logistic_normal_tau_max,
        # Dense targets are only needed by the soft-CE loss or the meanflow term.
        return_dense_targets=(args.target_loss == "soft_ce" or args.meanflow_weight > 0),
    )
def maybe_make_prefix_block_masks(
    args: argparse.Namespace,
    attn_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Randomly turn some rows into "prefix + one corrupted block" training examples.

    Each row of ``attn_mask`` (expected 2-D: batch x length) is selected with probability
    ``args.prefix_block_prob``. For a selected row, the valid positions are cut into
    consecutive blocks of ``args.prefix_block_len`` and one block is drawn uniformly.
    That block is marked in the force-corrupt mask, and attention is cleared for every
    valid position after it.

    Returns ``(attn_mask, None)`` when the probability is <= 0 or no row was selected.
    Otherwise returns a modified clone of ``attn_mask`` and a bool mask of the same shape.
    """
    prob = float(args.prefix_block_prob)
    if prob <= 0.0:
        return attn_mask, None
    dev = attn_mask.device
    n_rows, _seq_len = attn_mask.shape
    span = max(1, int(args.prefix_block_len))
    chosen = torch.rand(n_rows, device=dev) < prob
    if not bool(chosen.any().item()):
        return attn_mask, None
    visible = attn_mask.clone()
    corrupt = torch.zeros_like(attn_mask, dtype=torch.bool)
    for row in torch.nonzero(chosen).view(-1).tolist():
        positions = torch.nonzero(attn_mask[row]).view(-1)
        count = int(positions.numel())
        if count == 0:
            continue
        num_spans = max(1, math.ceil(count / span))
        lo = span * int(torch.randint(0, num_spans, (1,), device=dev).item())
        hi = min(lo + span, count)
        corrupt[row, positions[lo:hi]] = True
        # Everything after the corrupted block is hidden, so only the prefix is context.
        if hi < count:
            visible[row, positions[hi:]] = False
    return visible, corrupt
def dataloader_perf_kwargs(args: argparse.Namespace, device: torch.device) -> dict:
    """Extra ``DataLoader`` kwargs that only make sense with worker processes.

    * Memory is pinned only for CUDA devices with ``num_workers > 0``.
    * Workers are persistent whenever there are any.
    * ``prefetch_factor`` is set only when there are workers and
      ``args.dataloader_prefetch_factor`` is positive.
    """
    uses_workers = args.num_workers > 0
    perf = {
        "pin_memory": device.type == "cuda" and uses_workers,
        "persistent_workers": uses_workers,
    }
    if uses_workers and args.dataloader_prefetch_factor > 0:
        perf["prefetch_factor"] = int(args.dataloader_prefetch_factor)
    return perf
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Strip a DistributedDataParallel wrapper and then a ``torch.compile`` ``_orig_mod`` wrapper."""
    inner = model.module if isinstance(model, DistributedDataParallel) else model
    return getattr(inner, "_orig_mod", inner)
def rank_zero(rank: int) -> bool:
    """Return True only for the primary (rank 0) process."""
    primary_rank = 0
    return primary_rank == rank
def save_checkpoint(
    path: Path,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    args: argparse.Namespace,
    tokenizer: BpeTextTokenizer,
    step: int,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
    ema_state: dict[str, torch.Tensor] | None = None,
) -> None:
    """Write a training checkpoint to ``path`` via a temporary file and a rename.

    The payload always holds:

    * ``model``: ``model.state_dict()``;
    * ``optimizer``: the optimizer state dict;
    * ``args``: ``vars(args)``;
    * ``step``;
    * ``vocab_size``: ``args.effective_vocab_size`` if set and non-zero, else
      ``tokenizer.vocab_size``;
    * ``eos_id``.

    With ``ema_state``, ``ema_model`` covers every model state key. Keys tracked in
    ``ema_state`` are cast to the model's dtype on CPU; all other keys are CPU copies of
    the model's own tensors. ``scheduler`` adds its state dict.

    The payload is saved to ``<path>.tmp`` and then moved over ``path`` with
    ``Path.replace``. If either step fails, the temporary file is removed and the error
    is re-raised.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    model_state = model.state_dict()
    payload = {
        "model": model_state,
        "optimizer": optimizer.state_dict(),
        "args": vars(args),
        "step": step,
        "vocab_size": int(getattr(args, "effective_vocab_size", 0) or tokenizer.vocab_size),
        "eos_id": tokenizer.eos_id,
    }
    if ema_state is not None:
        payload["ema_model"] = {
            k: ema_state[k].to(dtype=v.dtype, device="cpu") if k in ema_state else v.detach().cpu()
            for k, v in model_state.items()
        }
    if scheduler is not None:
        payload["scheduler"] = scheduler.state_dict()
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    try:
        torch.save(payload, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        # Don't leave a partial/stale temp file next to the real checkpoint.
        tmp_path.unlink(missing_ok=True)
        raise
def _demo_corrupt_acc(state: torch.Tensor, ids: torch.Tensor, mask: torch.Tensor) -> float:
    """Fraction of ``mask`` positions where ``state.argmax(-1)`` equals ``ids`` (0.0 if ``mask`` is empty)."""
    hits = (state.argmax(-1) == ids) & mask
    return float(hits.float().sum() / mask.float().sum().clamp_min(1))


@torch.no_grad()
def run_demo(args, model, tokenizer, batch, device) -> None:
    """Print a fill-in-the-blank demo on the first sequence of ``batch``.

    The target text is printed first. Then, for every ratio in the comma-separated
    ``args.demo_mask_ratios``:

    * an initial state is built with ``fill_blank_init``;
    * it is decoded with ``soft_residual_decode``;
    * the first 500 characters of the argmax text of the initial and final states are
      printed;
    * the argmax accuracy on the positions in the returned ``mask`` is printed. This is
      presumably the set of blanked positions — confirm against ``fill_blank_init``.

    The model runs in eval mode and is always put back into train mode on exit, even if
    decoding raises.
    """
    # Initial-state width follows the model (main() builds it with args.effective_vocab_size,
    # which --vocab_size_override can change). Falls back to the tokenizer vocab when
    # unset — the same rule save_checkpoint uses for "vocab_size".
    vocab_size = int(getattr(args, "effective_vocab_size", 0) or tokenizer.vocab_size)
    model.eval()
    try:
        ids = batch["ids"][:1].to(device)
        attn_mask = batch["attn_mask"][:1].to(device)
        target = tokenizer.decode(ids[0].tolist())
        print(f"[demo] target: {target[:500]}", flush=True)
        for ratio in [float(x) for x in args.demo_mask_ratios.split(",") if x.strip()]:
            init, mask = fill_blank_init(
                ids,
                vocab_size,
                args.target_prob,
                ratio,
                args.demo_fill_t,
                args.eps,
                noise_mode=args.noise_init,
                noise_sigma=args.noise_sigma,
            )
            final = soft_residual_decode(
                model,
                init,
                attn_mask,
                args.infer_steps,
                args.decode_damping,
                args.max_gamma,
                args.eps,
                solver=args.decode_solver,
            )
            tag = f"[demo mask_ratio={ratio:.2f}]"
            print(f"{tag} init : {tokenizer.decode(init.argmax(-1)[0].tolist())[:500]}", flush=True)
            print(f"{tag} final: {tokenizer.decode(final.argmax(-1)[0].tolist())[:500]}", flush=True)
            acc_init = _demo_corrupt_acc(init, ids, mask)
            acc_final = _demo_corrupt_acc(final, ids, mask)
            print(f"{tag} corrupt_acc_init={acc_init:.4f} corrupt_acc_final={acc_final:.4f}", flush=True)
    finally:
        model.train()
def main() -> None:
args = parse_args()
t_bin_edges = parse_t_bin_edges(args.log_t_bin_edges)
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
ddp = world_size > 1
if ddp:
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
if args.global_batch_size > 0:
args.grad_accum = max(1, math.ceil(args.global_batch_size / max(1, args.batch_size * world_size)))
args.resolved_detokenizer = infer_detokenizer_name(raw_path=args.data_path, explicit=args.detokenizer)
seed_all(args.seed + rank)
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
torch.backends.cuda.matmul.allow_tf32 = bool(args.allow_tf32)
torch.backends.cudnn.allow_tf32 = bool(args.allow_tf32)
def ddp_barrier() -> None:
if not ddp:
return
if device.type == "cuda":
dist.barrier(device_ids=[local_rank])
else:
dist.barrier()
tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
model_vocab_size = int(args.vocab_size_override) if int(args.vocab_size_override) > 0 else int(tokenizer.vocab_size)
args.effective_vocab_size = model_vocab_size
categorical_wrong_corpus_unigram_probs = load_corpus_unigram_probs(
args.categorical_wrong_corpus_unigram_path,
model_vocab_size,
args.categorical_wrong_corpus_unigram_alpha,
device,
)
if rank_zero(rank) and categorical_wrong_corpus_unigram_probs is not None:
top_probs, top_ids = torch.topk(categorical_wrong_corpus_unigram_probs.detach().cpu(), k=min(10, model_vocab_size))
top_summary = ",".join(f"{int(i)}:{float(p):.3e}" for i, p in zip(top_ids, top_probs))
print(
"loaded_categorical_wrong_corpus_unigram="
f"{args.categorical_wrong_corpus_unigram_path} "
f"alpha={args.categorical_wrong_corpus_unigram_alpha} top={top_summary}",
flush=True,
)
shuffle_epoch_sampler = None
if args.elf_conditional_hf:
if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle or args.streaming or args.record_pad_truncate:
raise ValueError("--elf_conditional_hf is mutually exclusive with text streaming/packing modes")
if args.conditional_pad_token == "pad":
conditional_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else tokenizer.eos_id
loss_on_pad = False
else:
conditional_pad_id = tokenizer.eos_id
loss_on_pad = True
dataset = load_elf_conditional_dataset(args.data_path, max_records=args.max_records)
sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
collate_fn=ELFConditionalCollator(
conditional_pad_id,
max_len=args.max_len,
max_input_len=args.max_input_len,
loss_on_pad=loss_on_pad,
),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=True,
)
shuffle_epoch_sampler = sampler
dataset_size = (
f"elf_conditional_hf:{len(dataset)}"
f":max_input_len={args.max_input_len}:pad={conditional_pad_id}:loss_on_pad={int(loss_on_pad)}"
)
elif args.record_pad_truncate:
if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle:
raise ValueError("--record_pad_truncate is mutually exclusive with --wrap/--owt_cached_chunks/--online_chunk_shuffle")
if args.record_pad_token == "pad":
record_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else tokenizer.eos_id
else:
record_pad_id = tokenizer.eos_id
dataset = RecordPadTruncateTextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records_per_epoch=args.max_records,
detokenizer=args.resolved_detokenizer,
add_eos=args.record_add_eos,
add_special_tokens=args.record_add_special_tokens,
shuffle_buffer_size=args.record_shuffle_buffer,
seed=args.seed,
epoch=0,
)
shuffle_epoch_sampler = dataset
loader = DataLoader(
dataset,
batch_size=args.batch_size,
collate_fn=FixedPadCollator(record_pad_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=True,
)
dataset_size = (
"record_pad_truncate"
f":pad={record_pad_id}:add_eos={int(args.record_add_eos)}"
f":add_special={int(args.record_add_special_tokens)}"
f":shuffle_buffer={args.record_shuffle_buffer}"
)
elif args.owt_cached_chunks:
if not args.wrap or args.wrap_mode != "stream":
raise ValueError("--owt_cached_chunks requires --wrap --wrap_mode stream")
if not args.owt_chunk_cache_dir:
raise ValueError("--owt_cached_chunks requires --owt_chunk_cache_dir")
if args.openwebtext_split not in {"train_minus_100k", "all"}:
raise ValueError("--owt_cached_chunks is intended for OWT train/all splits")
if ddp and not rank_zero(rank):
ddp_barrier()
dataset = CachedWrappedTextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
cache_dir=args.owt_chunk_cache_dir,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records=args.max_records,
detokenizer=args.resolved_detokenizer,
rebuild=args.owt_chunk_cache_rebuild,
build_cache=rank_zero(rank),
write_batch_size=args.owt_chunk_cache_write_batch,
)
if ddp and rank_zero(rank):
ddp_barrier()
if args.owt_exact_repeat_per_chunk > 0:
sampler = ExactRepeatDistributedSampler(
len(dataset),
args.owt_exact_repeat_per_chunk,
seed=args.seed,
rank=rank,
world_size=world_size,
)
shuffle_epoch_sampler = None
else:
sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset)
shuffle_epoch_sampler = sampler
loader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=False,
)
dataset_size = f"owt_cached_chunks:{len(dataset)}"
elif args.wrap and args.online_chunk_shuffle:
if args.wrap_mode != "stream":
raise ValueError("--online_chunk_shuffle currently requires --wrap_mode stream")
dataset = ShuffledWrappedStreamingTextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records_per_epoch=args.max_records,
detokenizer=args.resolved_detokenizer,
chunk_buffer_size=args.online_chunk_shuffle_buffer,
seed=args.seed,
epoch=0,
)
shuffle_epoch_sampler = dataset
loader = DataLoader(
dataset,
batch_size=args.batch_size,
collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=False,
)
dataset_size = f"wrapped_stream_online_shuffle:{args.online_chunk_shuffle_buffer}"
elif args.wrap:
dataset = WrappedStreamingTextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records_per_epoch=args.max_records,
detokenizer=args.resolved_detokenizer,
wrap_mode=args.wrap_mode,
wrap_record_buffer_size=args.wrap_record_buffer_size,
)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=False,
)
dataset_size = f"wrapped_{args.wrap_mode}"
elif args.streaming:
dataset = StreamingTextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records_per_epoch=args.max_records,
stride=args.stride,
detokenizer=args.resolved_detokenizer,
)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=False,
)
dataset_size = "streaming"
else:
dataset = TextSequenceDataset(
args.data_path,
tokenizer,
max_len=args.max_len,
text_column=args.text_column,
txt_record_mode=args.txt_record_mode,
openwebtext_split=args.openwebtext_split,
max_records=args.max_records,
stride=args.stride,
detokenizer=args.resolved_detokenizer,
)
sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len),
num_workers=args.num_workers,
**dataloader_perf_kwargs(args, device),
drop_last=False,
)
dataset_size = len(dataset)
model = EndpointPredictor(
vocab_size=model_vocab_size,
max_len=args.max_len,
d_model=args.d_model,
n_heads=args.n_heads,
n_layers=args.n_layers,
dim_ff=args.dim_ff,
dropout=args.dropout,
model_type=args.model_type,
cond_dim=args.cond_dim,
input_format=args.state_format,
output_bias=args.output_bias,
norm_type=args.norm_type,
).to(device)
if args.activation_checkpointing and hasattr(model, "set_activation_checkpointing"):
model.set_activation_checkpointing(
True,
interval=args.activation_checkpoint_interval,
scope=args.activation_checkpoint_scope,
)
if args.torch_compile:
model = torch.compile(model, mode=args.compile_mode if args.compile_mode != "default" else None)
trainable_model = model
if ddp:
trainable_model = DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
static_graph=bool(args.ddp_static_graph),
gradient_as_bucket_view=bool(args.ddp_gradient_as_bucket_view),
)
optimizer = configure_adamw_optimizer(unwrap_model(trainable_model), args, device)
ema_state = init_ema_state(unwrap_model(trainable_model)) if args.ema_decay > 0 else None
def lr_lambda(step: int) -> float:
if args.warmup_steps > 0 and step < args.warmup_steps:
return max(1e-12, float(step + 1) / float(args.warmup_steps))
if args.lr_schedule == "constant_warmup":
return 1.0
progress = (step - args.warmup_steps) / max(1, args.total_steps - args.warmup_steps)
min_lr_ratio = float(args.min_lr) / max(float(args.lr), 1e-12)
return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * min(max(progress, 0.0), 1.0))))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
start_step = 1
if args.init_model_path:
if args.resume_path:
raise ValueError("--init_model_path is mutually exclusive with --resume_path")
ckpt = torch.load(args.init_model_path, map_location=device)
init_state = dict(ckpt["model"])
model_state = unwrap_model(trainable_model).state_dict()
if "pos_embed" in init_state and "pos_embed" in model_state and init_state["pos_embed"].shape != model_state["pos_embed"].shape:
if args.init_pos_embed_mode == "strict":
raise ValueError(
f"init pos_embed shape {tuple(init_state['pos_embed'].shape)} != "
f"model pos_embed shape {tuple(model_state['pos_embed'].shape)}; "
"set --init_pos_embed_mode repeat or interpolate to adapt"
)
src = init_state["pos_embed"]
target_len = int(model_state["pos_embed"].shape[1])
if src.shape[1] > target_len:
init_state["pos_embed"] = src[:, :target_len].contiguous()
elif args.init_pos_embed_mode == "repeat":
reps = math.ceil(target_len / int(src.shape[1]))
init_state["pos_embed"] = src.repeat(1, reps, 1)[:, :target_len].contiguous()
else:
init_state["pos_embed"] = F.interpolate(
src.transpose(1, 2),
size=target_len,
mode="linear",
align_corners=True,
).transpose(1, 2).contiguous()
if rank_zero(rank):
print(
f"adapted_init_pos_embed from={tuple(src.shape)} to={tuple(init_state['pos_embed'].shape)} "
f"mode={args.init_pos_embed_mode}",
flush=True,
)
for optional_bias_key in ("output_layer.linear.bias", "out_proj.bias"):
if optional_bias_key in init_state and optional_bias_key not in model_state:
init_state.pop(optional_bias_key)
if rank_zero(rank):
print(f"dropped_init_output_bias key={optional_bias_key} because current model has output_bias=False", flush=True)
unwrap_model(trainable_model).load_state_dict(init_state)
if ema_state is not None and "ema_model" in ckpt:
for k, v in ckpt["ema_model"].items():
if k in ema_state:
ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype))
if rank_zero(rank):
print(f"initialized_model_from={args.init_model_path} start_step={start_step}", flush=True)
if args.resume_path:
ckpt = torch.load(args.resume_path, map_location=device)
unwrap_model(trainable_model).load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if "scheduler" in ckpt:
scheduler.load_state_dict(ckpt["scheduler"])
else:
scheduler.last_epoch = int(ckpt.get("step", 0))
if ema_state is not None and "ema_model" in ckpt:
for k, v in ckpt["ema_model"].items():
if k in ema_state:
ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype))
start_step = int(ckpt.get("step", 0)) + 1
if rank_zero(rank):
print(f"resumed_from={args.resume_path} start_step={start_step}", flush=True)
save_dir = Path(args.save_dir)
if rank_zero(rank):
save_dir.mkdir(parents=True, exist_ok=True)
(save_dir / "args.json").write_text(json.dumps(vars(args), indent=2), encoding="utf-8")
print(json.dumps({
"device": str(device),
"rank": rank,
"world_size": world_size,
"samples": dataset_size,
"vocab_size": model_vocab_size,
"tokenizer_vocab_size": tokenizer.vocab_size,
"save_dir": str(save_dir),
"batch_size": args.batch_size,
"grad_accum": args.grad_accum,
"effective_batch_size": args.batch_size * args.grad_accum * world_size,
"global_batch_size": args.global_batch_size,
"lr_schedule": args.lr_schedule,
"optimizer": args.optimizer,
"warmup_steps": args.warmup_steps,
"min_lr": args.min_lr,
"weight_decay": args.weight_decay,
"output_weight_decay": args.output_weight_decay,
"adamw_param_groups": args.adamw_param_groups,
"adam_beta1": args.adam_beta1,
"adam_beta2": args.adam_beta2,
"adam_eps": args.adam_eps,
"muon_momentum": args.muon_momentum,
"muon_ns_steps": args.muon_ns_steps,
"muon_update_scale": args.muon_update_scale,
"ema_decay": args.ema_decay,
"ema_start_step": args.ema_start_step,
"model_type": args.model_type,
"output_bias": args.output_bias,
"norm_type": args.norm_type,
"t_sampling_mode": args.t_sampling_mode,
"t_sampling_power": args.t_sampling_power,
"t_sampling_eps": args.t_sampling_eps,
"dual_t": args.dual_t,
"corrupt_t_mode": args.corrupt_t_mode,
"corrupt_min_t": args.corrupt_min_t,
"corrupt_max_t": args.corrupt_max_t,
"prefix_block_prob": args.prefix_block_prob,
"prefix_block_len": args.prefix_block_len,
"mask_ratio_floor_schedule": args.mask_ratio_floor_schedule,
"dirichlet_endpoint_mode": args.dirichlet_endpoint_mode,
"dirichlet_semantic_t_mode": args.dirichlet_semantic_t_mode,
"dirichlet_semantic_t_value": args.dirichlet_semantic_t_value,
"dirichlet_semantic_t_curve": args.dirichlet_semantic_t_curve,
"dirichlet_semantic_t_power": args.dirichlet_semantic_t_power,
"endpoint_sequence_random_prob_alpha": args.endpoint_sequence_random_prob_alpha,
"categorical_wrong_from_full_vocab": args.categorical_wrong_from_full_vocab,
"categorical_wrong_from_batch_valid_tokens": args.categorical_wrong_from_batch_valid_tokens,
"categorical_wrong_basin_token_ids": args.categorical_wrong_basin_token_ids,
"categorical_wrong_basin_prob": args.categorical_wrong_basin_prob,
"categorical_wrong_unigram_prob": args.categorical_wrong_unigram_prob,
"categorical_wrong_uniform_prob": args.categorical_wrong_uniform_prob,
"categorical_wrong_corpus_unigram_path": args.categorical_wrong_corpus_unigram_path,
"categorical_wrong_corpus_unigram_alpha": args.categorical_wrong_corpus_unigram_alpha,
"categorical_wrong_basin_shared_prob": args.categorical_wrong_basin_shared_prob,
"categorical_wrong_unigram_shared_prob": args.categorical_wrong_unigram_shared_prob,
"mask_mixture_original_prob": args.mask_mixture_original_prob,
"mask_mixture_lowk_prob": args.mask_mixture_lowk_prob,
"mask_mixture_lowcorrupt_prob": args.mask_mixture_lowcorrupt_prob,
"mask_mixture_block_prob": args.mask_mixture_block_prob,
"mask_mixture_all_prob": args.mask_mixture_all_prob,
"mask_mixture_lowk_clean_tokens": args.mask_mixture_lowk_clean_tokens,
"mask_mixture_lowcorrupt_tokens": args.mask_mixture_lowcorrupt_tokens,
"mask_mixture_block_tokens": args.mask_mixture_block_tokens,
"simplex_bridge_sampler": args.simplex_bridge_sampler,
"logistic_normal_sigma_min": args.logistic_normal_sigma_min,
"logistic_normal_sigma_max": args.logistic_normal_sigma_max,
"logistic_normal_tau_min": args.logistic_normal_tau_min,
"logistic_normal_tau_max": args.logistic_normal_tau_max,
"torch_compile": args.torch_compile,
"compile_mode": args.compile_mode,
"state_format": args.state_format,
"target_loss": args.target_loss,
"meanflow_weight": args.meanflow_weight,
"rollout_train_prob": args.rollout_train_prob,
"rollout_train_steps": args.rollout_train_steps,
"rollout_train_infer_steps": args.rollout_train_infer_steps,
"rollout_train_temp": args.rollout_train_temp,
"rollout_train_max_gamma": args.rollout_train_max_gamma,
"rollout_train_corrupt_only": args.rollout_train_corrupt_only,
"rollout_train_samplewise": args.rollout_train_samplewise,
"rollout_train_compute_always": args.rollout_train_compute_always,
"bridge_noise_init": args.bridge_noise_init,
"noise_sigma": args.noise_sigma,
"allow_tf32": args.allow_tf32,
"activation_checkpointing": args.activation_checkpointing,
"activation_checkpoint_interval": args.activation_checkpoint_interval,
"activation_checkpoint_scope": args.activation_checkpoint_scope,
"ddp_static_graph": args.ddp_static_graph,
"ddp_gradient_as_bucket_view": args.ddp_gradient_as_bucket_view,
"blocking_data_transfer": args.blocking_data_transfer,
"dataloader_prefetch_factor": args.dataloader_prefetch_factor,
"full_train_stats": args.full_train_stats,
"record_pad_truncate": args.record_pad_truncate,
"record_add_eos": args.record_add_eos,
"record_add_special_tokens": args.record_add_special_tokens,
"record_pad_token": args.record_pad_token,
"record_shuffle_buffer": args.record_shuffle_buffer,
"wrap": args.wrap,
"wrap_mode": args.wrap_mode,
"wrap_record_buffer_size": args.wrap_record_buffer_size,
"owt_cached_chunks": args.owt_cached_chunks,
"owt_chunk_cache_dir": args.owt_chunk_cache_dir,
"owt_chunk_cache_rebuild": args.owt_chunk_cache_rebuild,
"owt_chunk_cache_write_batch": args.owt_chunk_cache_write_batch,
"owt_exact_repeat_per_chunk": args.owt_exact_repeat_per_chunk,
"online_chunk_shuffle": args.online_chunk_shuffle,
"online_chunk_shuffle_buffer": args.online_chunk_shuffle_buffer,
"openwebtext_split": args.openwebtext_split,
"detokenizer": args.detokenizer,
"resolved_detokenizer": args.resolved_detokenizer,
"num_workers": args.num_workers,
"latest_every": args.latest_every,
"resume_path": args.resume_path,
}, indent=2), flush=True)
use_amp = args.bf16 and device.type == "cuda"
data_epoch = 0
data_iter = iter(loader)
running = []
start = time.time()
trainable_model.train()
optimizer.zero_grad(set_to_none=True)
for step in range(start_step, args.total_steps + 1):
last_batch = None
for micro_step in range(args.grad_accum):
try:
batch = next(data_iter)
except StopIteration:
data_epoch += 1
if shuffle_epoch_sampler is not None and hasattr(shuffle_epoch_sampler, "set_epoch"):
shuffle_epoch_sampler.set_epoch(data_epoch)
data_iter = iter(loader)
batch = next(data_iter)
last_batch = batch
non_blocking = device.type == "cuda" and not args.blocking_data_transfer
ids = batch["ids"].to(device, non_blocking=non_blocking)
attn_mask = batch["attn_mask"].to(device, non_blocking=non_blocking)
loss_mask = batch.get("loss_mask")
cond_seq_mask = batch.get("cond_seq_mask")
if loss_mask is not None:
loss_mask = loss_mask.to(device, non_blocking=non_blocking)
if cond_seq_mask is not None:
cond_seq_mask = cond_seq_mask.to(device, non_blocking=non_blocking)
bridge_input_mask = loss_mask if loss_mask is not None else attn_mask
bridge_attn_mask, force_corrupt_mask = maybe_make_prefix_block_masks(args, bridge_input_mask)
model_t = sample_time(
ids.size(0),
args.min_t,
args.max_t,
device,
mode=args.t_sampling_mode,
power=args.t_sampling_power,
eps=args.t_sampling_eps,
)
bridge = make_bridge(
args,
ids,
bridge_attn_mask,
model_vocab_size,
force_t=resolve_bridge_force_t(args, model_t),
force_corrupt_mask=force_corrupt_mask,
categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs,
)
if loss_mask is not None:
model_attn_mask = attn_mask
if cond_seq_mask is not None and float(args.label_drop_prob) > 0.0:
drop = torch.rand(ids.size(0), device=device) < float(args.label_drop_prob)
model_attn_mask = attn_mask & ~(cond_seq_mask & drop.view(-1, 1))
bridge.attn_mask = model_attn_mask
loss_t_weight = sample_loss_t_weight(args, model_t)
sync_grad = micro_step == args.grad_accum - 1
sync_context = (
trainable_model.no_sync()
if ddp
and isinstance(trainable_model, DistributedDataParallel)
and not args.ddp_static_graph
and not sync_grad
else nullcontext()
)
with sync_context:
grad_enabled_before_rollout = torch.is_grad_enabled()
loss_state, rollout_train_applied, rollout_train_sample_mask = maybe_make_rollout_train_state(
args,
trainable_model,
bridge.state,
model_t,
bridge.attn_mask,
bridge.corrupt_mask,
)
grad_enabled_after_rollout = torch.is_grad_enabled()
with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
logits = trainable_model(loss_state, model_t, bridge.attn_mask)
logits_requires_grad = bool(logits.requires_grad)
loss_logits = logits
if args.prior_center_loss_beta > 0:
prior_state = make_prior_center_state(
args,
ids.size(0),
ids.size(1),
model_vocab_size,
device,
)
with torch.no_grad():
prior_logits = trainable_model(prior_state, model_t, bridge.attn_mask)
loss_logits = logits.float() - float(args.prior_center_loss_beta) * prior_logits.float()
if args.target_loss == "soft_ce":
if args.loss_t_weight_mode == "none":
raw_loss = masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask)
else:
raw_loss = weighted_masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask, loss_t_weight)
else:
if args.loss_t_weight_mode == "none":
raw_loss = masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask)
else:
raw_loss = weighted_masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask, loss_t_weight)
raw_loss_requires_grad = bool(raw_loss.requires_grad)
if args.meanflow_weight > 0:
if args.state_format == "prob":
current_probs = loss_state
else:
current_probs = loss_state.float().exp()
meanflow_loss = masked_meanflow_l2(
logits,
current_probs=current_probs,
target_probs=bridge.target_probs,
mask=bridge.corrupt_mask,
)
total_loss = raw_loss + float(args.meanflow_weight) * meanflow_loss
else:
meanflow_loss = raw_loss.new_zeros(())
total_loss = raw_loss
if float(args.rollout_train_prob) > 0.0:
total_loss = total_loss + zero_param_anchor(unwrap_model(trainable_model), total_loss)
loss = total_loss / max(1, args.grad_accum)
loss.backward()
last_micro_step = micro_step == args.grad_accum - 1
heavy_stats = args.full_train_stats or (step % args.log_every == 0 and last_micro_step)
collect_stats = rank_zero(rank) and (
args.full_train_stats
or args.log_train_stats_each_step
or heavy_stats
)
if collect_stats:
# Keep window logging cheap: aggregate accuracy/count diagnostics
# from every micro-batch, while expensive extras stay log-step only.
def scalar(x: torch.Tensor) -> float:
return float(x.detach().float().cpu())
batch_weight = float(ids.size(0))
corrupt_count = bridge.corrupt_mask.float().sum()
stats = {
"loss": float(total_loss.detach().cpu()),
"loss__weight": scalar(corrupt_count),
"loss_recon": float(raw_loss.detach().cpu()),
"loss_recon__weight": scalar(corrupt_count),
"loss_meanflow": float(meanflow_loss.detach().cpu()),
"loss_meanflow__weight": scalar(corrupt_count),
"mean_model_t": float(model_t.mean().detach().cpu()),
"mean_model_t__weight": batch_weight,
"mean_corrupt_t": float(bridge.t.mean().detach().cpu()),
"mean_corrupt_t__weight": batch_weight,
"mean_loss_t_weight": float(loss_t_weight.mean().detach().cpu()),
"mean_loss_t_weight__weight": batch_weight,
"prior_center_loss_beta": float(args.prior_center_loss_beta),
"rollout_train_applied": float(rollout_train_applied),
"rollout_train_applied__weight": batch_weight,
"grad_enabled_before_rollout": float(grad_enabled_before_rollout),
"grad_enabled_after_rollout": float(grad_enabled_after_rollout),
"logits_requires_grad": float(logits_requires_grad),
"raw_loss_requires_grad": float(raw_loss_requires_grad),
}
with torch.no_grad():
pred = loss_logits.detach().argmax(dim=-1)
attn_mask_f = bridge.attn_mask.float()
corrupt_mask_f = bridge.corrupt_mask.float()
attn_count = attn_mask_f.sum()
correct_all = ((pred == bridge.ids) & bridge.attn_mask).float().sum()
correct_corrupt = ((pred == bridge.ids) & bridge.corrupt_mask).float().sum()
if scalar(attn_count) > 0.0:
stats.update({
"acc_all": scalar(correct_all / attn_count.clamp_min(1.0)),
"acc_all__weight": scalar(attn_count),
"corrupt_frac": scalar(corrupt_count / attn_count.clamp_min(1.0)),
"corrupt_frac__weight": scalar(attn_count),
})
if scalar(corrupt_count) > 0.0:
stats.update({
"acc_corrupt": scalar(correct_corrupt / corrupt_count.clamp_min(1.0)),
"acc_corrupt__weight": scalar(corrupt_count),
"loss_corrupt": float(raw_loss.detach().cpu()),
"loss_corrupt__weight": scalar(corrupt_count),
"wrong_frac": scalar(bridge.wrong_mask.float().sum() / corrupt_count.clamp_min(1.0)),
"wrong_frac__weight": scalar(corrupt_count),
})
init_pred = loss_state.detach().argmax(dim=-1)
init_correct = ((init_pred == bridge.ids) & bridge.corrupt_mask).float().sum()
stats.update({
"init_acc_corrupt": scalar(init_correct / corrupt_count.clamp_min(1.0)),
"init_acc_corrupt__weight": scalar(corrupt_count),
})
if args.log_t_bin_stats:
total_corrupt = corrupt_count.clamp_min(1.0)
sample_t = bridge.t.detach().float().view(-1)
for lo, hi in zip(t_bin_edges[:-1], t_bin_edges[1:]):
if hi >= t_bin_edges[-1]:
sample_mask = (sample_t >= lo) & (sample_t <= hi)
else:
sample_mask = (sample_t >= lo) & (sample_t < hi)
pos_mask = bridge.corrupt_mask & sample_mask[:, None]
denom = pos_mask.float().sum()
if scalar(denom) <= 0.0:
continue
tag = f"{lo:.1f}_{min(hi, 1.0):.1f}".replace(".", "p")
correct_bin = ((pred == bridge.ids) & pos_mask).float().sum()
stats[f"acc_corrupt_t_{tag}"] = scalar(correct_bin / denom.clamp_min(1.0))
stats[f"acc_corrupt_t_{tag}__weight"] = scalar(denom)
stats[f"corrupt_frac_t_{tag}"] = scalar(denom / total_corrupt)
stats[f"corrupt_frac_t_{tag}__weight"] = scalar(corrupt_count)
model_for_stats = unwrap_model(trainable_model)
out_linear = getattr(getattr(model_for_stats, "output_layer", None), "linear", None)
if out_linear is None:
out_linear = getattr(model_for_stats, "out_proj", None)
if out_linear is not None:
out_w = out_linear.weight.detach().float()
out_g = out_linear.weight.grad
stats["out_w_norm"] = float(out_w.norm().cpu())
stats["out_g_norm"] = float(out_g.detach().float().norm().cpu()) if out_g is not None else 0.0
if heavy_stats:
stats.update(summarize_batch(
loss_logits.detach(),
bridge.ids,
bridge.attn_mask,
bridge.corrupt_mask,
target_probs=bridge.target_probs if args.target_loss == "soft_ce" else None,
include_loss=True,
))
if args.log_t_bin_stats:
stats.update(summarize_t_bin_acc(
loss_logits.detach(),
bridge.ids,
bridge.corrupt_mask,
bridge.t.detach(),
edges=t_bin_edges,
))
stats.update({
"init_gold_top10": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 10).cpu()),
"init_gold_top100": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 100).cpu()),
})
if rollout_train_sample_mask is not None:
applied_pos = bridge.corrupt_mask & rollout_train_sample_mask.view(-1, 1)
kept_pos = bridge.corrupt_mask & (~rollout_train_sample_mask.view(-1, 1))
stats.update({
"rollout_applied_pos_frac": float((applied_pos.float().sum() / bridge.corrupt_mask.float().sum().clamp_min(1.0)).detach().cpu()),
"init_acc_rollout_applied": float(masked_acc(loss_state.detach(), bridge.ids, applied_pos).cpu()),
"init_acc_rollout_kept": float(masked_acc(loss_state.detach(), bridge.ids, kept_pos).cpu()),
"logit_acc_rollout_applied": float(masked_acc(loss_logits.detach(), bridge.ids, applied_pos).cpu()),
"logit_acc_rollout_kept": float(masked_acc(loss_logits.detach(), bridge.ids, kept_pos).cpu()),
})
running.append(stats)
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
scheduler.step()
if ema_state is not None and step >= args.ema_start_step:
update_ema_state(ema_state, unwrap_model(trainable_model), args.ema_decay)
optimizer.zero_grad(set_to_none=True)
if step % args.log_every == 0:
if rank_zero(rank):
avg = average_stats_window(running) if running else {}
elapsed = time.time() - start
print(" ".join([
f"step={step}",
f"micro_steps={step * args.grad_accum}",
f"elapsed={elapsed:.1f}s",
f"lr={scheduler.get_last_lr()[0]:.6e}",
] + [f"{k}={v:.4f}" for k, v in avg.items()]), flush=True)
running = []
start = time.time()
if rank_zero(rank) and args.eval_every > 0 and step % args.eval_every == 0:
run_demo(args, unwrap_model(trainable_model), tokenizer, last_batch, device)
if rank_zero(rank) and args.save_every > 0 and step % args.save_every == 0:
save_checkpoint(save_dir / f"step_{step:07d}.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state)
save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state)
elif rank_zero(rank) and args.latest_every > 0 and step % args.latest_every == 0:
save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state)
if rank_zero(rank):
save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, args.total_steps, scheduler, ema_state)
if ddp:
dist.destroy_process_group()
# Script entry point: launch training only when this file is executed
# directly (e.g. `python train.py ...` or via a launcher), never on import.
if __name__ == "__main__":
    main()