| |
| from __future__ import annotations |
|
|
| import argparse |
| import inspect |
| import json |
| import math |
| import os |
| import random |
| import time |
| from contextlib import nullcontext |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| from torch.nn.parallel import DistributedDataParallel |
| from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler |
|
|
| from flowtext_lab.bridges import make_dirichlet_bridge_batch, make_prob_bridge_batch |
|
|
| try: |
| from flowtext_lab.bridges import make_gaussian_bridge_batch |
| except ImportError: |
| make_gaussian_bridge_batch = None |
| from flowtext_lab.data import ( |
| CachedWrappedTextSequenceDataset, |
| EosPadCollator, |
| FixedPadCollator, |
| RecordPadTruncateTextSequenceDataset, |
| ShuffledWrappedStreamingTextSequenceDataset, |
| StreamingTextSequenceDataset, |
| TextSequenceDataset, |
| WrappedStreamingTextSequenceDataset, |
| ) |
| from flowtext_lab.decode import fill_blank_init, soft_residual_decode |
| from flowtext_lab.metrics import ( |
| masked_acc, |
| masked_ce, |
| masked_meanflow_l2, |
| masked_soft_ce, |
| masked_topk_acc, |
| summarize_batch, |
| summarize_t_bin_acc, |
| ) |
| from flowtext_lab.model import EndpointPredictor |
| from flowtext_lab.text_detokenization import infer_detokenizer_name |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface of this training script.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``
        (which reads ``sys.argv``). ``--data_path``, ``--tokenizer_path`` and
        ``--save_dir`` are required; every other option has a default.

    The section comments below only group related flags for readability; all
    flags are registered on one flat parser, in the order shown (which is also
    the order ``--help`` prints them).
    """
    p = argparse.ArgumentParser()
    # --- Data source, tokenization and dataset format ---------------------
    p.add_argument("--data_path", required=True)
    p.add_argument("--text_column", default=None)
    p.add_argument("--txt_record_mode", choices=["auto", "line", "eot"], default="auto")
    p.add_argument("--detokenizer", default="auto")
    p.add_argument("--openwebtext_split", choices=["all", "train_minus_100k", "valid_last_100k"], default="all")
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--save_dir", required=True)
    p.add_argument("--max_records", type=int, default=0)
    # ELF-style conditional datasets (see load_elf_conditional_dataset /
    # ELFConditionalCollator, which take max_records / max_input_len).
    p.add_argument(
        "--elf_conditional_hf",
        action="store_true",
        help="Load ELF-style seq2seq Arrow datasets with condition_input_ids/input_ids columns.",
    )
    p.add_argument("--eval_data_path", default="")
    p.add_argument("--dataset_cache_dir", default="")
    p.add_argument("--max_input_len", type=int, default=64)
    p.add_argument("--conditional_pad_token", choices=["pad", "eos"], default="eos")
    p.add_argument("--label_drop_prob", type=float, default=0.0)
    # --- Input pipeline: streaming, record mode, packing and shuffling -----
    p.add_argument("--streaming", action="store_true")
    p.add_argument(
        "--record_pad_truncate",
        action="store_true",
        help=(
            "ELF-style OWT input pipeline: one record -> one example, no global "
            "stream packing, truncate to max_len and fixed-pad in the collator."
        ),
    )
    p.add_argument(
        "--record_add_eos",
        action="store_true",
        help="Append tokenizer EOS/SEP to each record in --record_pad_truncate mode.",
    )
    p.add_argument(
        "--record_add_special_tokens",
        action="store_true",
        help="Let the tokenizer add model special tokens in --record_pad_truncate mode.",
    )
    p.add_argument(
        "--record_pad_token",
        choices=["pad", "eos"],
        default="pad",
        help="Pad id for --record_pad_truncate. ELF-style uses tokenizer PAD when available.",
    )
    p.add_argument(
        "--record_shuffle_buffer",
        type=int,
        default=10000,
        help="Bounded online shuffle buffer for --record_pad_truncate; 0 disables it.",
    )
    p.add_argument("--wrap", action="store_true", help="Use wrapped packing: [BOS] + payload + [EOS].")
    p.add_argument(
        "--wrap_mode",
        choices=["stream", "record"],
        default="stream",
        help="stream keeps the original fixed-width token stream; record packs whole records and only truncates overlength records.",
    )
    p.add_argument("--wrap_record_buffer_size", type=int, default=200)
    p.add_argument(
        "--owt_cached_chunks",
        action="store_true",
        help="OWT-only FLM/Duo-style cached chunk pool with epoch-level sampler shuffle.",
    )
    p.add_argument(
        "--owt_chunk_cache_dir",
        default="",
        help="Cache directory for --owt_cached_chunks. Stores meta.json and chunks.i32.bin.",
    )
    p.add_argument("--owt_chunk_cache_rebuild", action="store_true")
    p.add_argument("--owt_chunk_cache_write_batch", type=int, default=4096)
    # NOTE(review): presumably wired to ExactRepeatDistributedSampler
    # (repeats_per_item); 0 keeps the regular sampler — confirm at call site.
    p.add_argument(
        "--owt_exact_repeat_per_chunk",
        type=int,
        default=0,
        help=(
            "Pilot-only cached OWT sampler: repeat each cached chunk exactly this "
            "many times, shuffle once deterministically, then shard across ranks."
        ),
    )
    p.add_argument(
        "--online_chunk_shuffle",
        action="store_true",
        help="Use an online bounded shuffle buffer over wrapped stream chunks.",
    )
    p.add_argument("--online_chunk_shuffle_buffer", type=int, default=10000)
    p.add_argument("--max_len", type=int, default=128)
    p.add_argument("--stride", type=int, default=0)
    # --- Batching and data loading ----------------------------------------
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--dataloader_prefetch_factor", type=int, default=2)
    p.add_argument("--blocking_data_transfer", action="store_true")
    p.add_argument("--global_batch_size", type=int, default=0, help="If >0, overrides grad_accum with ceil(global_batch_size / batch_size).")
    p.add_argument("--grad_accum", type=int, default=1)
    # --- Optimization: schedule, AdamW/Muon, EMA, clipping -----------------
    p.add_argument("--total_steps", type=int, default=1000)
    p.add_argument("--lr", type=float, default=6e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    # Negative default (-1.0) means "not set", per the help text (">=0").
    p.add_argument(
        "--output_weight_decay",
        type=float,
        default=-1.0,
        help="If >=0, use this AdamW weight decay only for the output projection weight.",
    )
    p.add_argument("--adam_beta1", type=float, default=0.9)
    p.add_argument("--adam_beta2", type=float, default=0.95)
    p.add_argument("--adam_eps", type=float, default=1e-8)
    p.add_argument("--optimizer", choices=["adamw", "muon"], default="adamw")
    p.add_argument("--muon_momentum", type=float, default=0.95)
    p.add_argument("--muon_ns_steps", type=int, default=5)
    p.add_argument("--muon_update_scale", type=float, default=1.0)
    p.add_argument("--ema_decay", type=float, default=0.0)
    p.add_argument("--ema_start_step", type=int, default=0)
    p.add_argument("--warmup_steps", type=int, default=2000)
    p.add_argument("--lr_schedule", choices=["constant_warmup", "cosine"], default="cosine")
    p.add_argument("--min_lr", type=float, default=6e-5)
    p.add_argument(
        "--adamw_param_groups",
        choices=["nanogpt", "all_decay"],
        default="nanogpt",
        help="nanoGPT decays only 2D matmul/embedding tensors; all_decay applies weight decay to every parameter.",
    )
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--seed", type=int, default=42)
    # --- Model architecture ------------------------------------------------
    p.add_argument("--d_model", type=int, default=384)
    p.add_argument("--cond_dim", type=int, default=128)
    p.add_argument("--n_layers", type=int, default=6)
    p.add_argument("--n_heads", type=int, default=6)
    p.add_argument("--dim_ff", type=int, default=1536)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument(
        "--output_bias",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to include a vocabulary bias in the final output projection. Default false to avoid global-frequency basin amplification.",
    )
    p.add_argument(
        "--norm_type",
        choices=["rmsnorm", "layernorm"],
        default="rmsnorm",
        help="Normalization used by DDiT blocks/final head. Default rmsnorm follows modern ELF/T5-style practice.",
    )
    p.add_argument(
        "--vocab_size_override",
        type=int,
        default=0,
        help="Debug-only: train with a compact/remapped vocabulary of this size instead of tokenizer.vocab_size.",
    )
    p.add_argument("--model_type", choices=["transformer", "ddit"], default="transformer")
    # --- State representation, bridge and training targets ----------------
    # state_format also selects prob vs log-prob handling in state_to_probs,
    # probs_to_state and make_prior_center_state.
    p.add_argument("--state_format", choices=["logprob", "prob"], default="logprob")
    p.add_argument("--bridge", choices=["prob", "dirichlet", "gaussian"], default="prob")
    p.add_argument("--target_loss", choices=["hard_ce", "soft_ce"], default="soft_ce")
    p.add_argument("--meanflow_weight", type=float, default=0.0)
    # Consumed by sample_loss_t_weight; loss_t_min_weight (>0) is a clamp
    # floor applied after any non-"none" mode.
    p.add_argument(
        "--loss_t_weight_mode",
        choices=["none", "linear_t", "quadratic_t", "one_minus_t_floor", "drop_low_t"],
        default="none",
        help="Per-sample t loss weighting. one_minus_t_floor downweights easy high-t targets.",
    )
    p.add_argument("--loss_t_min_weight", type=float, default=0.0)
    p.add_argument("--loss_t_drop_below", type=float, default=0.2)
    p.add_argument(
        "--prior_center_loss_beta",
        type=float,
        default=0.0,
        help="If >0, train CE on logits - beta * logits(prior_state, t), reducing unconditional prior attraction.",
    )
    # Only "uniform" is implemented (make_prior_center_state raises otherwise).
    p.add_argument("--prior_center_state", choices=["uniform"], default="uniform")
    # --- Self-rollout training (read by maybe_make_rollout_train_state) ----
    p.add_argument(
        "--rollout_train_prob",
        type=float,
        default=0.0,
        help=(
            "Probability of replacing the supervised input state with a detached "
            "self-rollout state before computing the reconstruction loss."
        ),
    )
    p.add_argument(
        "--rollout_train_steps",
        type=int,
        default=1,
        help="Number of no-grad self-rollout updates before the supervised loss.",
    )
    p.add_argument(
        "--rollout_train_infer_steps",
        type=int,
        default=64,
        help="Decode horizon used to set the one-step flowmap gamma during rollout training.",
    )
    p.add_argument(
        "--rollout_train_temp",
        type=float,
        default=1.45,
        help="Endpoint temperature used in the no-grad rollout-training update.",
    )
    p.add_argument(
        "--rollout_train_max_gamma",
        type=float,
        default=1.0,
        help="Clamp for the no-grad rollout-training flowmap gamma. Set <=0 to disable the clamp.",
    )
    p.add_argument(
        "--rollout_train_corrupt_only",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Only replace corrupted positions with the self-rollout state, preserving clean anchors.",
    )
    p.add_argument(
        "--rollout_train_samplewise",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Apply rollout stochastically per sample instead of per micro-batch. "
            "This always computes the rollout forward when prob>0, which keeps "
            "Tensor Core work steady while preserving the requested rollout mix."
        ),
    )
    p.add_argument(
        "--rollout_train_compute_always",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "For batchwise rollout, always compute the detached rollout forward "
            "even when the batch coin decides not to apply it. This preserves "
            "the old rollout objective while increasing steady Tensor Core work."
        ),
    )
    p.add_argument(
        "--rollout_train_debug_return_base",
        action=argparse.BooleanOptionalAction,
        default=False,
        help=(
            "Debug only: still run the detached rollout forward, but return the "
            "original bridge state to isolate rollout-forward side effects from "
            "the rollout-state training distribution."
        ),
    )
    # --- Time sampling (model/flow t and corruption t) ---------------------
    p.add_argument("--target_prob", type=float, default=0.99)
    p.add_argument("--min_t", type=float, default=0.0)
    p.add_argument("--max_t", type=float, default=1.0)
    # NOTE(review): presumably forwarded to sample_time(mode, power, eps)
    # together with --t_sampling_power / --t_sampling_eps — confirm.
    p.add_argument(
        "--t_sampling_mode",
        choices=["uniform", "power_low", "power_high", "logit_uniform"],
        default="uniform",
        help=(
            "How to sample model/flow time t. power_low uses u**gamma to "
            "oversample low t; power_high uses 1-(1-u)**gamma to oversample high t; "
            "logit_uniform samples uniformly in logit(t)."
        ),
    )
    p.add_argument("--t_sampling_power", type=float, default=1.0)
    p.add_argument("--t_sampling_eps", type=float, default=1e-4)
    p.add_argument("--dual_t", action="store_true", help="Use separate model/flow time and corruption/support time.")
    p.add_argument("--corrupt_t_mode", choices=["same", "independent", "constant"], default="independent")
    p.add_argument("--corrupt_t_value", type=float, default=0.0)
    # NOTE(review): None presumably means "inherit --min_t / --max_t";
    # confirm where these are consumed.
    p.add_argument("--corrupt_min_t", type=float, default=None)
    p.add_argument("--corrupt_max_t", type=float, default=None)
    # --- Corruption / masking pattern ---------------------------------------
    p.add_argument(
        "--prefix_block_prob",
        type=float,
        default=0.0,
        help="Train on grow-context states: expose only prefix plus one active block, corrupting the active block and masking future tokens.",
    )
    p.add_argument("--prefix_block_len", type=int, default=128)
    p.add_argument("--min_mask_ratio", type=float, default=0.1)
    p.add_argument("--max_mask_ratio", type=float, default=1.0)
    p.add_argument(
        "--mask_ratio_floor_schedule",
        choices=["none", "one_minus_t"],
        default="none",
        help=(
            "Optional per-sample lower-bound schedule for mask ratio. "
            "one_minus_t uses max(min_mask_ratio, 1-t), so t=0 is all-mask "
            "while high-t reverts to the configured min_mask_ratio."
        ),
    )
    # Mask-mixture branches: active when any of the *_prob values below is >0.
    p.add_argument(
        "--mask_mixture_original_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of using the original uniform mask-ratio sampler.",
    )
    p.add_argument(
        "--mask_mixture_lowk_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of keeping only K clean tokens and corrupting the rest.",
    )
    p.add_argument(
        "--mask_mixture_lowcorrupt_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting only K valid tokens and keeping the rest anchored.",
    )
    p.add_argument(
        "--mask_mixture_block_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting one contiguous valid-token block and keeping the rest anchored.",
    )
    p.add_argument(
        "--mask_mixture_all_prob",
        type=float,
        default=0.0,
        help="If any mask-mixture prob is >0, probability of corrupting every valid token.",
    )
    # The three token-count lists below are raw comma-separated strings,
    # split by the consumer.
    p.add_argument(
        "--mask_mixture_lowk_clean_tokens",
        default="1,2,4,8,16,32,64",
        help="Comma-separated clean-token counts sampled uniformly for the low-K mixture branch.",
    )
    p.add_argument(
        "--mask_mixture_lowcorrupt_tokens",
        type=str,
        default="1,2,4,8,16,32,64",
        help="Comma-separated corrupt-token counts sampled uniformly for the low-corrupt-K mixture branch.",
    )
    p.add_argument(
        "--mask_mixture_block_tokens",
        type=str,
        default="64,128",
        help="Comma-separated contiguous block lengths sampled uniformly for the block-corrupt mixture branch.",
    )
    p.add_argument(
        "--clean_state_mode",
        choices=["onehot", "bridge"],
        default="onehot",
        help="State for non-corrupted support tokens: exact one-hot gold, or a same-t bridge sample around gold.",
    )
    # NOTE(review): kept as a raw string (no type=); presumably parsed by the
    # consumer (possibly a list/schedule spec) — confirm.
    p.add_argument("--wrong_token_replace_prob", default="0.0")
    p.add_argument("--wrong_token_schedule", choices=["constant", "hard", "linear_t", "exp_t", "exp_k1"], default="constant")
    p.add_argument("--wrong_token_exp_k", type=float, default=1.0)
    # --- Dirichlet / categorical endpoint bridge ----------------------------
    p.add_argument("--dirichlet_concentration_min", type=float, default=1.0)
    p.add_argument("--dirichlet_concentration_max", type=float, default=1024.0)
    p.add_argument(
        "--dirichlet_endpoint_mode",
        choices=["bernoulli_wrong", "dual_t_line", "categorical_dual_t"],
        default="bernoulli_wrong",
    )
    p.add_argument("--dirichlet_semantic_t_mode", choices=["same", "independent", "constant"], default="same")
    p.add_argument("--dirichlet_semantic_t_value", type=float, default=0.0)
    p.add_argument(
        "--dirichlet_semantic_t_curve",
        choices=["linear", "logit_power"],
        default="linear",
        help=(
            "Transform semantic t before categorical_dual_t Bernoulli gold sampling. "
            "logit_power uses p=t^gamma/(t^gamma+(1-t)^gamma)."
        ),
    )
    p.add_argument(
        "--dirichlet_semantic_t_power",
        type=float,
        default=1.0,
        help="Gamma for --dirichlet_semantic_t_curve logit_power.",
    )
    p.add_argument(
        "--endpoint_sequence_random_prob_alpha",
        type=float,
        default=0.0,
        help=(
            "For categorical_dual_t: with probability alpha*(1-t), force all "
            "corrupted positions in a sequence to use the random endpoint branch. "
            "0 disables the sequence-level gate."
        ),
    )
    p.add_argument(
        "--categorical_wrong_from_full_vocab",
        action="store_true",
        help="For categorical_dual_t only: sample the non-gold endpoint from the full vocab, allowing it to equal gold.",
    )
    p.add_argument(
        "--categorical_wrong_from_batch_valid_tokens",
        action="store_true",
        help="For categorical_dual_t only: sample the non-gold endpoint from valid tokens in the current batch.",
    )
    p.add_argument(
        "--categorical_wrong_basin_token_ids",
        default="",
        help=(
            "For categorical_dual_t only: comma-separated token ids for a basin-bank wrong endpoint branch. "
            "Used when any categorical_wrong_*_prob is >0."
        ),
    )
    p.add_argument("--categorical_wrong_basin_prob", type=float, default=0.0)
    p.add_argument("--categorical_wrong_unigram_prob", type=float, default=0.0)
    p.add_argument("--categorical_wrong_uniform_prob", type=float, default=0.0)
    p.add_argument(
        "--categorical_wrong_corpus_unigram_path",
        default="",
        help=(
            "Optional torch/numpy file containing OWT corpus unigram counts or probabilities. "
            "When set, categorical_wrong_unigram_prob samples from this corpus distribution."
        ),
    )
    p.add_argument(
        "--categorical_wrong_corpus_unigram_alpha",
        type=float,
        default=1.0,
        help="Exponent applied to corpus unigram counts before normalization; values <1 flatten high-frequency tokens.",
    )
    p.add_argument(
        "--categorical_wrong_basin_shared_prob",
        type=float,
        default=0.0,
        help="Per-sequence probability of sharing one basin-bank wrong token across all corrupted positions.",
    )
    p.add_argument(
        "--categorical_wrong_unigram_shared_prob",
        type=float,
        default=0.0,
        help="Per-sequence probability of sharing one corpus/batch-unigram wrong token across all corrupted positions.",
    )
    p.add_argument(
        "--simplex_bridge_sampler",
        choices=["dirichlet", "logistic_normal_linear_mean"],
        default="dirichlet",
        help="How to sample simplex bridge states after choosing the endpoint.",
    )
    p.add_argument("--logistic_normal_sigma_min", type=float, default=0.18)
    p.add_argument("--logistic_normal_sigma_max", type=float, default=2.2)
    p.add_argument("--logistic_normal_tau_min", type=float, default=0.65)
    p.add_argument("--logistic_normal_tau_max", type=float, default=1.15)
    p.add_argument("--dirichlet_scale", type=float, default=128.0)
    p.add_argument("--dirichlet_floor", type=float, default=1e-3)
    # Numerical floor; state_to_probs / probs_to_state clamp probabilities to it.
    p.add_argument("--eps", type=float, default=1e-8)
    # --- Logging, evaluation and checkpointing ------------------------------
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--eval_every", type=int, default=100)
    p.add_argument("--save_every", type=int, default=500)
    p.add_argument("--latest_every", type=int, default=0, help="If >0, also refresh save_dir/latest.pt at this interval.")
    p.add_argument("--resume_path", default="", help="Optional checkpoint path to resume model/optimizer/scheduler state.")
    p.add_argument(
        "--init_model_path",
        default="",
        help="Optional checkpoint path used only to initialize model weights; optimizer/scheduler/step start fresh.",
    )
    p.add_argument(
        "--init_pos_embed_mode",
        choices=["strict", "repeat", "interpolate"],
        default="strict",
        help="When --init_model_path has a different pos_embed length, adapt it to the current max_len.",
    )
    # --- Decoding / demo generation -----------------------------------------
    p.add_argument("--infer_steps", type=int, default=64)
    p.add_argument("--decode_damping", type=float, default=1.0)
    p.add_argument("--max_gamma", type=float, default=1.0)
    p.add_argument("--decode_solver", choices=["simple", "flowmap"], default="flowmap")
    p.add_argument("--noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal")
    p.add_argument("--bridge_noise_init", choices=["uniform", "logistic_normal"], default="logistic_normal")
    # NOTE(review): negative default presumably means "use the built-in
    # sigma" — confirm in the decode/noise code.
    p.add_argument("--noise_sigma", type=float, default=-1.0)
    p.add_argument("--demo_mask_ratios", default="0.1,0.5,1.0")
    p.add_argument("--demo_fill_t", type=float, default=0.0)
    # --- Precision, memory and distributed-training knobs -------------------
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--allow_tf32", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--activation_checkpointing", action="store_true")
    p.add_argument("--activation_checkpoint_interval", type=int, default=1)
    p.add_argument("--activation_checkpoint_scope", choices=["block", "mlp"], default="block")
    p.add_argument("--ddp_static_graph", action="store_true")
    p.add_argument("--ddp_gradient_as_bucket_view", action=argparse.BooleanOptionalAction, default=True)
    # --- Train-time diagnostics ---------------------------------------------
    p.add_argument(
        "--full_train_stats",
        action="store_true",
        help="Compute heavy train diagnostics every micro-step. Default computes cheap stats every optimizer step and heavy loss diagnostics on log steps.",
    )
    p.add_argument(
        "--log_train_stats_each_step",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Accumulate lightweight train diagnostics from every micro-batch, then report weighted averages over each log window.",
    )
    p.add_argument(
        "--log_t_bin_stats",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Log corrupt-token accuracy bucketed by model t, making mixed-t train logs easier to read.",
    )
    # Validated by parse_t_bin_edges (strictly increasing, covering [0, 1]).
    p.add_argument(
        "--log_t_bin_edges",
        default="0,0.2,0.4,0.6,0.8,1.0",
        help="Comma-separated t-bin edges used when --log_t_bin_stats is enabled.",
    )
    # --- torch.compile --------------------------------------------------------
    p.add_argument("--torch_compile", action="store_true")
    p.add_argument("--compile_mode", choices=["default", "reduce-overhead", "max-autotune"], default="max-autotune")
    return p.parse_args()
|
|
|
|
| def _datasetdict_to_single_split(ds): |
| if hasattr(ds, "keys") and not hasattr(ds, "column_names"): |
| for key in ("train", "validation", "test"): |
| if key in ds: |
| return ds[key] |
| return ds[next(iter(ds.keys()))] |
| return ds |
|
|
|
|
def _load_arrow_dir_without_hf_metadata(path: str):
    """Load ``data-*.arrow`` shards under ``path`` with HF schema metadata stripped.

    Fallback for ELF datasets saved with newer ``List`` feature metadata: some
    nodes have an older ``datasets`` build that cannot parse feature type
    ``List``, even though pyarrow reads the Arrow files fine. Each shard is read
    as an Arrow IPC stream, only its schema metadata is removed, and the shards
    are concatenated so ``datasets.Dataset`` infers equivalent sequence columns.

    Raises:
        FileNotFoundError: if no ``data-*.arrow`` file sits directly in ``path``.
    """
    import pyarrow as pa
    import pyarrow.ipc as ipc
    from datasets import Dataset

    def _read_shard_without_metadata(shard_path: Path):
        # One memory-mapped IPC stream shard, re-cast to a metadata-free schema.
        with pa.memory_map(str(shard_path), "r") as mapped:
            shard = ipc.open_stream(mapped).read_all()
            return shard.cast(shard.schema.remove_metadata())

    shard_paths = sorted(Path(path).glob("data-*.arrow"))
    if len(shard_paths) == 0:
        raise FileNotFoundError(f"No data-*.arrow files found under {path}")
    shards = [_read_shard_without_metadata(shard_path) for shard_path in shard_paths]
    if len(shards) == 1:
        merged = shards[0]
    else:
        merged = pa.concat_tables(shards, promote_options="default")
    return Dataset(merged)
|
|
|
|
def load_elf_conditional_dataset(path: str, max_records: int = 0):
    """Load an ELF-style conditional dataset directory written by ``save_to_disk``.

    ``datasets.load_from_disk`` is tried first (collapsing a DatasetDict to a
    single split). If it fails with a ValueError mentioning the unknown
    ``List`` feature type, the raw Arrow shards are read instead via
    ``_load_arrow_dir_without_hf_metadata``. When ``max_records`` is positive,
    only the first ``max_records`` rows are kept.

    Raises:
        ValueError: if ``condition_input_ids`` or ``input_ids`` is missing, or
            any unrelated ValueError from ``load_from_disk`` (re-raised as is).
    """
    from datasets import load_from_disk

    try:
        ds = _datasetdict_to_single_split(load_from_disk(path))
    except ValueError as exc:
        is_list_feature_error = "Feature type 'List' not found" in str(exc)
        if not is_list_feature_error:
            raise
        ds = _load_arrow_dir_without_hf_metadata(path)

    if max_records and max_records > 0:
        keep = min(int(max_records), len(ds))
        ds = ds.select(range(keep))

    missing_columns = sorted({"condition_input_ids", "input_ids"} - set(ds.column_names))
    if missing_columns:
        raise ValueError(f"ELF conditional dataset {path} missing columns: {missing_columns}")
    return ds
|
|
|
|
class ELFConditionalCollator:
    """Collate ELF conditional examples into fixed-width ``[condition | target | pad]`` rows.

    Each example must provide ``condition_input_ids`` and ``input_ids``. The
    condition is cut to ``max_input_len`` tokens, the target fills whatever
    room remains up to ``max_len``, and the row is right-padded with
    ``pad_id``. If the (cut) condition alone is longer than ``max_len`` it is
    truncated to ``max_len`` and the row carries no target tokens.

    For a non-empty batch every returned tensor has shape
    ``(len(examples), max_len)``:

    * ``ids`` (long): token ids.
    * ``attn_mask`` (bool): True on condition and target positions.
    * ``loss_mask`` (bool): True on target positions, plus padding when
      ``loss_on_pad`` is set; never True on condition positions.
    * ``cond_seq_mask`` (bool): True on condition positions.
    """

    def __init__(self, pad_id: int, max_len: int, max_input_len: int, *, loss_on_pad: bool) -> None:
        self.pad_id = int(pad_id)
        self.max_len = int(max_len)
        self.max_input_len = int(max_input_len)
        self.loss_on_pad = bool(loss_on_pad)

    def _encode(self, example: dict) -> tuple[list[int], list[bool], list[bool], list[bool]]:
        """Build the (ids, attn, loss, cond) rows for one example."""
        width = self.max_len
        prefix = [int(tok) for tok in example["condition_input_ids"][: self.max_input_len]]
        room_for_target = max(0, width - len(prefix))
        suffix = [int(tok) for tok in example["input_ids"][:room_for_target]]
        tokens = (prefix + suffix)[:width]
        n_cond = min(len(prefix), width)
        n_target = max(0, len(tokens) - n_cond)
        n_pad = width - len(tokens)
        # Extending by an empty list is a no-op when the row is already full.
        tokens.extend([self.pad_id] * n_pad)

        attn = [True] * (n_cond + n_target) + [False] * n_pad
        cond = [True] * n_cond + [False] * (width - n_cond)
        if self.loss_on_pad:
            after_cond = [True] * (width - n_cond)
        else:
            after_cond = [True] * n_target + [False] * n_pad
        loss = [False] * n_cond + after_cond
        return tokens, attn, loss, cond

    def __call__(self, examples: list[dict]) -> dict[str, torch.Tensor]:
        encoded = [self._encode(example) for example in examples]
        columns = {
            "ids": ([row[0] for row in encoded], torch.long),
            "attn_mask": ([row[1] for row in encoded], torch.bool),
            "loss_mask": ([row[2] for row in encoded], torch.bool),
            "cond_seq_mask": ([row[3] for row in encoded], torch.bool),
        }
        return {name: torch.tensor(values, dtype=dtype) for name, (values, dtype) in columns.items()}
|
|
|
|
class ExactRepeatDistributedSampler(Sampler[int]):
    """Deterministic finite schedule with exact per-item repeat counts.

    Every index in ``range(dataset_size)`` appears exactly ``repeats_per_item``
    times in the global schedule. The schedule is shuffled once with a
    ``torch.Generator`` seeded by ``seed``, so every rank computes the same
    permutation. Rank ``r`` then takes every ``world_size``-th entry starting
    at ``r``. Iterating again replays the same order (there is no epoch
    reseeding).

    Args:
        dataset_size: number of distinct items; must be positive.
        repeats_per_item: exact repeat count per item; must be positive.
        seed: seed for the single global shuffle.
        rank: this process's rank; must satisfy ``0 <= rank < world_size``.
        world_size: number of ranks; must be positive and divide
            ``dataset_size * repeats_per_item``.

    Raises:
        ValueError: if any argument above is out of range.
    """

    def __init__(
        self,
        dataset_size: int,
        repeats_per_item: int,
        *,
        seed: int,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        self.dataset_size = int(dataset_size)
        self.repeats_per_item = int(repeats_per_item)
        self.seed = int(seed)
        self.rank = int(rank)
        self.world_size = int(world_size)
        if self.dataset_size <= 0:
            raise ValueError("dataset_size must be positive")
        if self.repeats_per_item <= 0:
            raise ValueError("repeats_per_item must be positive")
        # Checked before the modulo below, which would otherwise raise
        # ZeroDivisionError for world_size == 0.
        if self.world_size <= 0:
            raise ValueError("world_size must be positive")
        # An out-of-range rank would silently produce a shard whose length
        # disagrees with __len__ (or is empty).
        if not 0 <= self.rank < self.world_size:
            raise ValueError(f"rank={self.rank} must be in [0, world_size={self.world_size})")
        total = self.dataset_size * self.repeats_per_item
        if total % self.world_size != 0:
            raise ValueError(
                f"exact repeat schedule total={total} is not divisible by world_size={self.world_size}"
            )
        self.num_samples = total // self.world_size

    def __iter__(self):
        schedule = torch.arange(self.dataset_size, dtype=torch.long).repeat_interleave(
            self.repeats_per_item
        )
        # Fresh generator per iteration: identical permutation on every rank
        # and on every pass.
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        schedule = schedule[torch.randperm(schedule.numel(), generator=generator)]
        return iter(schedule[self.rank :: self.world_size].tolist())

    def __len__(self) -> int:
        return self.num_samples
|
|
|
|
def seed_all(seed: int) -> None:
    """Seed Python's ``random``, torch's CPU RNG and all CUDA device RNGs with ``seed``."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
|
|
|
def sample_time(
    batch: int,
    min_t: float,
    max_t: float,
    device: torch.device,
    mode: str = "uniform",
    power: float = 1.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Draw ``batch`` times in ``[min_t, max_t]`` from a warped uniform distribution.

    A uniform draw ``u`` in [0, 1) is mapped to ``z`` in [0, 1], then rescaled
    affinely to ``[min_t, max_t]``:

    * ``"uniform"``: ``z = u``.
    * ``"power_low"``: ``z = u**power`` (``power > 1`` favors small t).
    * ``"power_high"``: ``z = 1 - (1 - u)**power`` (``power > 1`` favors large t).
    * ``"logit_uniform"``: ``z`` is uniform in logit space between
      ``logit(eps)`` and ``logit(1 - eps)``.

    Raises:
        ValueError: for an unknown ``mode`` (raised after the random draw).
    """
    noise = torch.rand(batch, device=device)
    if mode == "uniform":
        warped = noise
    elif mode == "power_low" or mode == "power_high":
        exponent = float(power)
        if mode == "power_low":
            warped = noise.pow(exponent)
        else:
            warped = 1.0 - (1.0 - noise).pow(exponent)
    elif mode == "logit_uniform":
        low_logit = torch.logit(torch.tensor(float(eps), device=device))
        high_logit = torch.logit(torch.tensor(1.0 - float(eps), device=device))
        warped = torch.sigmoid(low_logit + (high_logit - low_logit) * noise)
    else:
        raise ValueError(f"Unknown t_sampling_mode: {mode}")
    span = max_t - min_t
    return min_t + span * warped
|
|
|
|
def sample_loss_t_weight(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor:
    """Per-sample loss weights as a function of model time ``t``.

    ``args.loss_t_weight_mode`` selects the curve, evaluated on ``t`` clamped
    to [0, 1]:

    * ``"none"``: all ones (returned directly, no floor applied).
    * ``"linear_t"``: ``t``.
    * ``"quadratic_t"``: ``t**2``.
    * ``"one_minus_t_floor"``: ``1 - t``.
    * ``"drop_low_t"``: 0 where ``t < args.loss_t_drop_below``, else 1.

    For every mode except ``"none"``, a positive ``args.loss_t_min_weight`` is
    then applied as a lower clamp.

    Raises:
        ValueError: for an unknown mode.
    """
    mode = args.loss_t_weight_mode
    if mode == "none":
        return torch.ones_like(model_t)
    t = model_t.clamp(0.0, 1.0)
    curves = {
        "linear_t": lambda: t,
        "quadratic_t": lambda: t.square(),
        "one_minus_t_floor": lambda: 1.0 - t,
        "drop_low_t": lambda: torch.where(
            t < float(args.loss_t_drop_below), torch.zeros_like(t), torch.ones_like(t)
        ),
    }
    curve = curves.get(mode)
    if curve is None:
        raise ValueError(f"Unknown loss_t_weight_mode: {args.loss_t_weight_mode}")
    weight = curve()
    floor = args.loss_t_min_weight
    return weight.clamp_min(float(floor)) if floor > 0 else weight
|
|
|
|
def parse_t_bin_edges(raw: str) -> tuple[float, ...]:
    """Parse and validate the comma-separated ``--log_t_bin_edges`` value.

    Blank entries (e.g. from a trailing comma) are ignored.

    Args:
        raw: comma-separated numbers, e.g. ``"0,0.2,0.4,0.6,0.8,1.0"``.

    Returns:
        The edges as a tuple of floats.

    Raises:
        ValueError: if an entry is not a number, fewer than two edges are given,
            any edge is NaN, the edges are not strictly increasing, or they do
            not cover [0, 1] (first edge > 0 or last edge < 1).
    """
    vals = tuple(float(x) for x in raw.split(",") if x.strip())
    if len(vals) < 2:
        raise ValueError("--log_t_bin_edges must contain at least two comma-separated values")
    # NaN compares False against everything, so without this check it would
    # slip through the ordering/coverage checks below and corrupt bucketing.
    if any(math.isnan(v) for v in vals):
        raise ValueError(f"--log_t_bin_edges must not contain NaN, got {raw!r}")
    if any(lo >= hi for lo, hi in zip(vals, vals[1:])):
        raise ValueError(f"--log_t_bin_edges must be strictly increasing, got {raw!r}")
    if vals[0] > 0.0 or vals[-1] < 1.0:
        raise ValueError(f"--log_t_bin_edges must cover [0, 1], got {raw!r}")
    return vals
|
|
|
|
def average_stats_window(running: list[dict[str, float]]) -> dict[str, float]:
    """Average a window of per-step stat dicts into one dict.

    A key ``k`` paired with ``k__weight`` in a step contributes to a weighted
    mean, but only when both the value and the weight are finite and the weight
    is positive. Steps where ``k`` has no weight key feed a plain mean, which
    is used only if no weighted contribution was collected. Non-finite values
    are skipped. Keys ending in ``__weight`` are never reported, keys that end
    up with no usable value are omitted, and output keys follow their order of
    first appearance in ``running``.
    """
    # dict.fromkeys keeps first-appearance order while de-duplicating.
    ordered_keys = list(
        dict.fromkeys(
            key for step in running for key in step if not key.endswith("__weight")
        )
    )
    result: dict[str, float] = {}
    for key in ordered_keys:
        weight_key = f"{key}__weight"
        numerator = 0.0
        denominator = 0.0
        unweighted: list[float] = []
        for step in running:
            if key not in step:
                continue
            value = float(step[key])
            if not math.isfinite(value):
                continue
            if weight_key not in step:
                unweighted.append(value)
                continue
            w = float(step[weight_key])
            if math.isfinite(w) and w > 0.0:
                numerator += value * w
                denominator += w
        if denominator > 0.0:
            result[key] = numerator / denominator
        elif unweighted:
            result[key] = sum(unweighted) / len(unweighted)
    return result
|
|
|
|
def weighted_masked_ce(logits: torch.Tensor, ids: torch.Tensor, mask: torch.Tensor, sample_weight: torch.Tensor) -> torch.Tensor:
    """Hard-label cross-entropy averaged with per-token and per-sample weights.

    ``logits`` is flattened over its first two dims to match ``ids.flatten()``,
    so it is expected as (batch, seq, vocab) with ``ids`` as (batch, seq).
    Each token's loss is weighted by ``mask * sample_weight[batch]``, and the
    weighted sum is divided by the total weight (floored at 1.0, so an all-zero
    weight returns 0).
    """
    token_weight = mask.float() * sample_weight.float().view(-1, 1)
    token_loss = F.cross_entropy(logits.flatten(0, 1), ids.flatten(), reduction="none").view_as(ids)
    normalizer = token_weight.sum().clamp_min(1.0)
    return (token_loss * token_weight).sum() / normalizer
|
|
|
|
def weighted_masked_soft_ce(
    logits: torch.Tensor,
    target_probs: torch.Tensor,
    mask: torch.Tensor,
    sample_weight: torch.Tensor,
) -> torch.Tensor:
    """Soft-target cross-entropy averaged with per-token and per-sample weights.

    The per-token loss is ``-sum(target_probs * log_softmax(logits))`` over the
    last dim, with ``target_probs`` cast to the log-prob dtype and the result
    upcast to float32. Weights are ``mask * sample_weight`` broadcast over the
    sequence (``sample_weight`` is viewed as ``(-1, 1)``). The weighted sum is
    divided by the total weight floored at 1.0.
    """
    token_weight = mask.float() * sample_weight.float().view(-1, 1)
    log_q = F.log_softmax(logits, dim=-1)
    cross = (target_probs.to(dtype=log_q.dtype) * log_q).sum(dim=-1)
    token_loss = cross.float().neg()
    normalizer = token_weight.sum().clamp_min(1.0)
    return (token_loss * token_weight).sum() / normalizer
|
|
|
|
def make_prior_center_state(
    args: argparse.Namespace,
    batch: int,
    length: int,
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Build the uninformative "prior" input state of shape (batch, length, vocab_size).

    Only ``args.prior_center_state == "uniform"`` is supported: every position
    gets probability ``1 / vocab_size`` in float32. When
    ``args.state_format == "prob"`` the probabilities are returned as is;
    otherwise their log is returned, after flooring at ``args.eps``.

    Raises:
        ValueError: for any other ``args.prior_center_state``.
    """
    kind = args.prior_center_state
    if kind != "uniform":
        raise ValueError(f"Unknown prior_center_state: {args.prior_center_state}")
    uniform = torch.full(
        (batch, length, vocab_size),
        1.0 / vocab_size,
        dtype=torch.float32,
        device=device,
    )
    if args.state_format != "prob":
        return torch.log(uniform.clamp_min(args.eps))
    return uniform
|
|
|
|
def state_to_probs(args: argparse.Namespace, state: torch.Tensor) -> torch.Tensor:
    """Convert a model input state into a normalized float32 distribution over the last dim.

    With ``args.state_format == "prob"`` the state is read as (possibly
    unnormalized) probabilities; any other format is read as log-probabilities
    and exponentiated. Values are floored at ``args.eps`` and renormalized to
    sum to one along the last dimension.
    """
    values = state.float()
    if args.state_format != "prob":
        values = values.exp()
    values = values.clamp_min(args.eps)
    total = values.sum(dim=-1, keepdim=True).clamp_min(args.eps)
    return values / total
|
|
|
|
def probs_to_state(args: argparse.Namespace, probs: torch.Tensor) -> torch.Tensor:
    """Convert probabilities into the model's input-state format.

    Probabilities are cast to float32, floored at ``args.eps`` and renormalized
    over the last dim. They are returned as is when
    ``args.state_format == "prob"``; otherwise their log (after re-flooring at
    ``args.eps``) is returned.
    """
    eps = args.eps
    p = probs.float().clamp_min(eps)
    p = p / p.sum(dim=-1, keepdim=True).clamp_min(eps)
    return p if args.state_format == "prob" else p.clamp_min(eps).log()
|
|
|
|
def zero_param_anchor(model: torch.nn.Module, like: torch.Tensor) -> torch.Tensor:
    """Return a scalar zero that is connected in autograd to every trainable parameter.

    Starts from a 0-dim zero built with ``like.new_zeros`` and adds ``0.0 *`` (the first
    element of) each ``requires_grad`` parameter. Adding the result to a loss gives every
    trainable parameter a zero gradient; presumably this keeps all parameters "used" for
    DDP -- confirm against the caller.
    """
    anchor = like.new_zeros(())
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # ``[:1].sum()`` equals ``[0]`` for non-empty parameters but, unlike indexing,
        # also works for zero-element parameters (where it yields 0 and still links p).
        anchor = anchor + p.reshape(-1)[:1].sum().float() * 0.0
    return anchor
|
|
|
|
@torch.no_grad()
def maybe_make_rollout_train_state(
    args: argparse.Namespace,
    model: torch.nn.Module,
    base_state: torch.Tensor,
    model_t: torch.Tensor,
    attn_mask: torch.Tensor,
    corrupt_mask: torch.Tensor,
) -> tuple[torch.Tensor, float, torch.Tensor | None]:
    """Optionally replace a bridge training state with a short model rollout from it.

    Starting from ``base_state`` at time ``model_t``, runs ``args.rollout_train_steps``
    gradient-free steps. Each step moves the probabilities a fraction ``gamma`` toward the
    softmax of the model's logits divided by ``args.rollout_train_temp``, and advances time
    by ``1 / args.rollout_train_infer_steps``. The rolled-out probabilities are then written
    back for the selected samples. With ``args.rollout_train_corrupt_only``, they are written
    back only at positions where ``corrupt_mask`` is set.

    Sample selection:
      * ``args.rollout_train_samplewise``: each sample independently with probability
        ``args.rollout_train_prob``. The rollout is always computed.
      * otherwise a single draw decides for the whole batch. On a miss the function
        returns early, unless ``args.rollout_train_compute_always`` is set, in which
        case the rollout is still computed and then discarded.

    Returns:
        ``(state, applied, apply_sample)``:
          * ``state``: ``base_state`` itself when nothing is applied (or when
            ``args.rollout_train_debug_return_base`` is set). Otherwise, the blended
            probabilities re-encoded by ``probs_to_state`` and detached. These are
            float32 regardless of ``base_state``'s dtype.
          * ``applied``: fraction of samples selected (0.0 or 1.0 in batch mode).
          * ``apply_sample``: per-sample boolean selection, or ``None`` when the
            function returned before computing a rollout.
    """
    prob = float(args.rollout_train_prob)
    steps = int(args.rollout_train_steps)
    # Feature disabled: leave the state untouched.
    if prob <= 0.0 or steps <= 0:
        return base_state, 0.0, None
    samplewise = bool(getattr(args, "rollout_train_samplewise", False))
    if samplewise:
        apply_sample = torch.rand(base_state.size(0), device=model_t.device) < prob
        # .cpu() copies the fraction to the host (a device sync on GPU).
        applied = float(apply_sample.float().mean().detach().cpu())
    else:
        # One draw for the whole batch; prob >= 1.0 skips the draw entirely.
        apply_batch = prob >= 1.0 or float(torch.rand((), device=model_t.device).item()) < prob
        if not apply_batch and not bool(getattr(args, "rollout_train_compute_always", False)):
            return base_state, 0.0, None
        apply_sample = torch.full((base_state.size(0),), bool(apply_batch), device=model_t.device)
        applied = 1.0 if apply_batch else 0.0


    probs = state_to_probs(args, base_state)
    # Kept for blending below; note this is base_state after eps-clamping and
    # renormalization by state_to_probs, not base_state itself.
    original_probs = probs
    rollout_t = model_t.float().clamp(0.0, 1.0)
    dt = 1.0 / max(1, int(args.rollout_train_infer_steps))
    # NOTE(review): redundant with the @torch.no_grad() decorator; harmless.
    with torch.no_grad():
        for _ in range(steps):
            logits = model(probs_to_state(args, probs), rollout_t, attn_mask)
            endpoint = torch.nn.functional.softmax(logits / float(args.rollout_train_temp), dim=-1)
            # Drop the logits reference before the next model call.
            del logits
            # Step fraction = dt / remaining time.
            # NOTE(review): this exceeds 1.0 once rollout_t > 1 - dt (and becomes
            # dt / eps at t == 1), overshooting past the endpoint before the clamp and
            # renormalize below, unless rollout_train_max_gamma > 0 caps it. Confirm
            # that this is intended.
            gamma = dt / (1.0 - rollout_t).clamp_min(args.eps)
            if float(args.rollout_train_max_gamma) > 0:
                gamma = gamma.clamp_max(float(args.rollout_train_max_gamma))
            # Assumes model_t holds one time per sample (or a single scalar), so this
            # broadcasts over (batch, length, vocab). TODO confirm.
            probs = probs + gamma.view(-1, 1, 1) * (endpoint - probs)
            probs = probs.clamp_min(args.eps)
            probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(args.eps)
            rollout_t = (rollout_t + dt).clamp_max(1.0)


    # Debug switch: the rollout is computed but the original state is returned,
    # together with the selection that would have been applied.
    if bool(getattr(args, "rollout_train_debug_return_base", False)):
        return base_state, float(applied), apply_sample.detach()


    # Nothing selected (samplewise miss, or a compute_always batch miss).
    if applied <= 0.0:
        return base_state, 0.0, apply_sample.detach()
    if args.rollout_train_corrupt_only:
        # Overwrite only corrupted positions of selected samples.
        apply_pos = corrupt_mask
        apply_pos = apply_pos & apply_sample.view(-1, 1)
        probs = torch.where(apply_pos.unsqueeze(-1), probs, original_probs)
    else:
        # Overwrite every position of selected samples.
        probs = torch.where(apply_sample.view(-1, 1, 1), probs, original_probs)
    return probs_to_state(args, probs).detach(), float(applied), apply_sample.detach()
|
|
|
|
@torch.no_grad()
def zeropower_via_newtonschulz5(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize ``g`` with the quintic Newton-Schulz iteration used by Muon.

    ``g`` is viewed as 2-D ``(g.shape[0], -1)`` in float32 and transposed if needed so the
    iteration runs on the wide orientation (rows <= cols). It is then scaled by its
    Frobenius norm and iterated ``max(1, steps)`` times. The result is reshaped to
    ``g.shape`` and cast back to ``g.dtype``.
    """
    coeff_a, coeff_b, coeff_c = 3.4445, -4.7750, 2.0315
    mat = g.float().reshape(g.shape[0], -1)
    tall = mat.size(0) > mat.size(1)
    if tall:
        mat = mat.t()
    mat = mat / (mat.norm() + 1e-7)
    n_iter = max(1, int(steps))
    for _ in range(n_iter):
        gram = mat @ mat.t()
        mat = coeff_a * mat + (coeff_b * gram + coeff_c * (gram @ gram)) @ mat
    if tall:
        mat = mat.t()
    return mat.reshape(g.shape).to(dtype=g.dtype)
|
|
|
|
class Muon(torch.optim.Optimizer):
    """Minimal Muon optimizer with an Adam fallback for parameters below two dimensions.

    Only parameters with ``requires_grad=True`` are kept. They form at most two param
    groups, in this order:

    * ``use_muon=True``: parameters with ``dim() >= 2``. Each keeps an EMA momentum
      buffer (``buf = momentum * buf + (1 - momentum) * grad``). That buffer is
      orthogonalized by ``zeropower_via_newtonschulz5`` and applied with step size
      ``lr * update_scale``.
    * ``use_muon=False``: the remaining parameters, updated with bias-corrected Adam
      using ``adam_betas`` and ``eps``.

    Both paths first apply decoupled weight decay ``p *= 1 - lr * weight_decay``
    whenever ``weight_decay`` is non-zero.
    """

    def __init__(
        self,
        params,
        *,
        lr: float,
        momentum: float = 0.95,
        ns_steps: int = 5,
        adam_betas: tuple[float, float] = (0.9, 0.95),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        update_scale: float = 1.0,
    ) -> None:
        trainable = [param for param in params if param.requires_grad]
        hyper = {
            "lr": lr,
            "momentum": momentum,
            "ns_steps": ns_steps,
            "adam_betas": adam_betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "update_scale": update_scale,
        }
        matrix_params = [param for param in trainable if param.dim() >= 2]
        other_params = [param for param in trainable if param.dim() < 2]
        groups = [
            {"params": members, "use_muon": use_muon, **hyper}
            for use_muon, members in ((True, matrix_params), (False, other_params))
            if members
        ]
        super().__init__(groups, hyper)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every group; returns the closure's loss when one is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            if group.get("use_muon", False):
                self._muon_group_step(group)
            else:
                self._adam_group_step(group)
        return loss

    def _muon_group_step(self, group: dict) -> None:
        """Momentum + Newton-Schulz orthogonalized update for one ``use_muon`` group."""
        lr = group["lr"]
        decay = group["weight_decay"]
        beta = group["momentum"]
        n_iter = group["ns_steps"]
        scale = group["update_scale"]
        for param in group["params"]:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Muon does not support sparse gradients")
            slot = self.state[param]
            if not slot:
                slot["momentum_buffer"] = torch.zeros_like(param)
            buf = slot["momentum_buffer"]
            buf.mul_(beta).add_(grad, alpha=1.0 - beta)
            direction = zeropower_via_newtonschulz5(buf, n_iter)
            if decay:
                param.mul_(1.0 - lr * decay)
            param.add_(direction, alpha=-lr * scale)

    def _adam_group_step(self, group: dict) -> None:
        """Bias-corrected Adam update with decoupled weight decay for one fallback group."""
        lr = group["lr"]
        decay = group["weight_decay"]
        beta1, beta2 = group["adam_betas"]
        eps = group["eps"]
        for param in group["params"]:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("Adam fallback does not support sparse gradients")
            slot = self.state[param]
            if not slot:
                slot["step"] = 0
                slot["exp_avg"] = torch.zeros_like(param)
                slot["exp_avg_sq"] = torch.zeros_like(param)
            first_moment = slot["exp_avg"]
            second_moment = slot["exp_avg_sq"]
            slot["step"] += 1
            t = slot["step"]
            if decay:
                param.mul_(1.0 - lr * decay)
            first_moment.mul_(beta1).add_(grad, alpha=1.0 - beta1)
            second_moment.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
            correction1 = 1.0 - beta1**t
            correction2 = 1.0 - beta2**t
            step_size = lr * (correction2**0.5) / correction1
            denom = second_moment.sqrt().add_(eps)
            param.addcdiv_(first_moment, denom, value=-step_size)
|
|
|
|
def configure_adamw_optimizer(
    model: torch.nn.Module,
    args: argparse.Namespace,
    device: torch.device,
) -> torch.optim.Optimizer:
    """Create the optimizer for the trainable parameters of *model*.

    When ``args.optimizer == "muon"``, returns a ``Muon`` over all trainable parameters.
    Otherwise builds AdamW (with ``fused=True`` on CUDA when the installed torch supports
    it). The param groups depend on ``args.adamw_param_groups``:

    * ``"all_decay"``: every parameter uses ``args.weight_decay``.
    * anything else: parameters with ``dim() >= 2`` use ``args.weight_decay`` and the
      rest use 0.0.

    In both AdamW layouts, when ``args.output_weight_decay`` exists and is ``>= 0``,
    parameters whose names end in ``output_layer.linear.weight`` or ``out_proj.weight``
    get their own group with that decay.
    NOTE(review): suffix matching would also catch any nested ``*.out_proj.weight``
    (e.g. attention output projections) if the model has them. Confirm this is intended.
    """
    named_params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    params = [p for _, p in named_params]
    if args.optimizer == "muon":
        return Muon(
            params,
            lr=args.lr,
            momentum=args.muon_momentum,
            ns_steps=args.muon_ns_steps,
            adam_betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_eps,
            weight_decay=args.weight_decay,
            update_scale=args.muon_update_scale,
        )

    can_fuse = "fused" in inspect.signature(torch.optim.AdamW).parameters
    fused_kwargs = {"fused": True} if (can_fuse and device.type == "cuda") else {}

    def make_adamw(param_spec, **group_defaults) -> torch.optim.Optimizer:
        # Shared AdamW hyper-parameters; group_defaults carries weight_decay when given.
        return torch.optim.AdamW(
            param_spec,
            lr=args.lr,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_eps,
            **group_defaults,
            **fused_kwargs,
        )

    output_decay = float(getattr(args, "output_weight_decay", -1.0))
    output_suffixes = ("output_layer.linear.weight", "out_proj.weight")
    output_params = (
        [p for name, p in named_params if name.endswith(output_suffixes)]
        if output_decay >= 0.0
        else []
    )
    output_ids = {id(p) for p in output_params}
    remaining = [p for p in params if id(p) not in output_ids]

    if args.adamw_param_groups == "all_decay":
        if not output_params:
            return make_adamw(params, weight_decay=args.weight_decay)
        return make_adamw(
            [
                {"params": remaining, "weight_decay": args.weight_decay},
                {"params": output_params, "weight_decay": output_decay},
            ]
        )

    # Matrices decay, output weights use their own decay, vectors/scalars do not decay.
    layout = (
        ([p for p in remaining if p.dim() >= 2], args.weight_decay),
        (output_params, output_decay),
        ([p for p in remaining if p.dim() < 2], 0.0),
    )
    return make_adamw([{"params": ps, "weight_decay": wd} for ps, wd in layout if ps])
|
|
|
|
@torch.no_grad()
def init_ema_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Snapshot the floating-point entries of ``model.state_dict()`` as detached clones.

    Non-floating entries (e.g. integer buffers) are omitted, so they are never averaged.
    """
    snapshot: dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        if not torch.is_floating_point(tensor):
            continue
        snapshot[name] = tensor.detach().clone()
    return snapshot
|
|
|
|
@torch.no_grad()
def update_ema_state(ema_state: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
    """In place, set ``ema = decay * ema + (1 - decay) * current`` for every tracked key.

    Keys of ``model.state_dict()`` that are absent from ``ema_state`` are ignored.
    """
    blend = 1.0 - decay
    current = model.state_dict()
    for name in current:
        if name not in ema_state:
            continue
        ema_state[name].mul_(decay).add_(current[name].detach(), alpha=blend)
|
|
|
|
| def resolve_bridge_force_t(args: argparse.Namespace, model_t: torch.Tensor) -> torch.Tensor | None: |
| if not args.dual_t or args.corrupt_t_mode == "same": |
| return model_t |
| if args.corrupt_t_mode == "constant": |
| return torch.full_like(model_t, float(args.corrupt_t_value)) |
| return None |
|
|
|
|
| def load_corpus_unigram_probs( |
| path: str, |
| vocab_size: int, |
| alpha: float, |
| device: torch.device, |
| ) -> torch.Tensor | None: |
| if not path: |
| return None |
| p = Path(path) |
| if not p.exists(): |
| raise FileNotFoundError(f"categorical wrong corpus unigram file not found: {p}") |
| if p.suffix == ".npy": |
| import numpy as np |
|
|
| data = torch.from_numpy(np.load(p)) |
| else: |
| loaded = torch.load(p, map_location="cpu") |
| if isinstance(loaded, dict): |
| for key in ("probs", "prob", "counts", "count", "unigram_counts"): |
| if key in loaded: |
| loaded = loaded[key] |
| break |
| data = torch.as_tensor(loaded) |
| probs = data.flatten().to(dtype=torch.float32, device="cpu") |
| if probs.numel() < vocab_size: |
| probs = F.pad(probs, (0, vocab_size - probs.numel())) |
| elif probs.numel() > vocab_size: |
| probs = probs[:vocab_size] |
| probs = probs.clamp_min(0.0) |
| if float(alpha) != 1.0: |
| probs = probs.pow(float(alpha)) |
| total = probs.sum() |
| if not bool(torch.isfinite(total).item()) or float(total.item()) <= 0.0: |
| raise ValueError(f"corpus unigram file has no positive finite mass: {p}") |
| return (probs / total).to(device=device) |
|
|
|
|
def make_bridge(
    args: argparse.Namespace,
    ids: torch.Tensor,
    attn_mask: torch.Tensor,
    vocab_size: int,
    force_t: torch.Tensor | None = None,
    force_corrupt_mask: torch.Tensor | None = None,
    categorical_wrong_corpus_unigram_probs: torch.Tensor | None = None,
):
    """Sample a training bridge batch with the builder selected by ``args.bridge``.

    Dispatch:
      * ``"prob"``: ``make_prob_bridge_batch``.
      * ``"gaussian"``: ``make_gaussian_bridge_batch``; raises ``RuntimeError`` if that
        optional import was unavailable.
      * anything else: ``make_dirichlet_bridge_batch``.

    The corruption time range is ``args.corrupt_min_t`` / ``args.corrupt_max_t``, falling
    back to ``args.min_t`` / ``args.max_t`` when those are ``None``. Returns whatever the
    selected builder returns.
    """
    lo_t = args.min_t if args.corrupt_min_t is None else args.corrupt_min_t
    hi_t = args.max_t if args.corrupt_max_t is None else args.corrupt_max_t
    # Arguments every builder receives.
    common = {
        "ids": ids,
        "attn_mask": attn_mask,
        "vocab_size": vocab_size,
        "target_prob": args.target_prob,
        "min_t": lo_t,
        "max_t": hi_t,
    }
    if args.bridge == "gaussian":
        if make_gaussian_bridge_batch is None:
            raise RuntimeError("bridge=gaussian requested, but make_gaussian_bridge_batch is unavailable")
        return make_gaussian_bridge_batch(**common, eps=args.eps, force_t=force_t)

    # Corruption-schedule and mask-mixture options shared by the prob and dirichlet
    # builders; each keyword matches the args attribute of the same name.
    shared_arg_names = (
        "min_mask_ratio",
        "max_mask_ratio",
        "wrong_token_replace_prob",
        "wrong_token_schedule",
        "wrong_token_exp_k",
        "eps",
        "state_format",
        "mask_ratio_floor_schedule",
        "mask_mixture_original_prob",
        "mask_mixture_lowk_prob",
        "mask_mixture_lowcorrupt_prob",
        "mask_mixture_block_prob",
        "mask_mixture_all_prob",
        "mask_mixture_lowk_clean_tokens",
        "mask_mixture_lowcorrupt_tokens",
        "mask_mixture_block_tokens",
        "clean_state_mode",
    )
    shared = {name: getattr(args, name) for name in shared_arg_names}
    shared["force_t"] = force_t
    shared["force_corrupt_mask"] = force_corrupt_mask

    if args.bridge == "prob":
        return make_prob_bridge_batch(
            **common,
            **shared,
            noise_init=args.bridge_noise_init,
            noise_sigma=args.noise_sigma,
        )

    # Dirichlet/simplex-only options; again keyword == args attribute name.
    dirichlet_arg_names = (
        "dirichlet_concentration_min",
        "dirichlet_concentration_max",
        "dirichlet_endpoint_mode",
        "dirichlet_semantic_t_mode",
        "dirichlet_semantic_t_value",
        "dirichlet_semantic_t_curve",
        "dirichlet_semantic_t_power",
        "endpoint_sequence_random_prob_alpha",
        "categorical_wrong_from_full_vocab",
        "categorical_wrong_from_batch_valid_tokens",
        "categorical_wrong_basin_token_ids",
        "categorical_wrong_basin_prob",
        "categorical_wrong_unigram_prob",
        "categorical_wrong_uniform_prob",
        "categorical_wrong_basin_shared_prob",
        "categorical_wrong_unigram_shared_prob",
        "simplex_bridge_sampler",
        "logistic_normal_sigma_min",
        "logistic_normal_sigma_max",
        "logistic_normal_tau_min",
        "logistic_normal_tau_max",
    )
    dirichlet_only = {name: getattr(args, name) for name in dirichlet_arg_names}
    wants_dense_targets = args.target_loss == "soft_ce" or args.meanflow_weight > 0
    return make_dirichlet_bridge_batch(
        **common,
        **shared,
        **dirichlet_only,
        categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs,
        return_dense_targets=wants_dense_targets,
    )
|
|
|
|
| def maybe_make_prefix_block_masks( |
| args: argparse.Namespace, |
| attn_mask: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: |
| prob = float(args.prefix_block_prob) |
| if prob <= 0.0: |
| return attn_mask, None |
| device = attn_mask.device |
| batch, length = attn_mask.shape |
| block_len = max(1, int(args.prefix_block_len)) |
| use_prefix = torch.rand(batch, device=device) < prob |
| if not bool(use_prefix.any().item()): |
| return attn_mask, None |
|
|
| effective_attn = attn_mask.clone() |
| force_corrupt = torch.zeros_like(attn_mask, dtype=torch.bool) |
| for b in use_prefix.nonzero(as_tuple=False).flatten().tolist(): |
| valid = attn_mask[b].nonzero(as_tuple=False).flatten() |
| if valid.numel() == 0: |
| continue |
| n_valid = int(valid.numel()) |
| n_blocks = max(1, math.ceil(n_valid / block_len)) |
| block_idx = int(torch.randint(0, n_blocks, (1,), device=device).item()) |
| start = block_idx * block_len |
| end = min(start + block_len, n_valid) |
| force_corrupt[b, valid[start:end]] = True |
| if end < n_valid: |
| effective_attn[b, valid[end:]] = False |
| return effective_attn, force_corrupt |
|
|
|
|
def dataloader_perf_kwargs(args: argparse.Namespace, device: torch.device) -> dict:
    """Extra ``DataLoader`` keyword arguments for throughput.

    Pinned memory is only requested for CUDA with worker processes. Workers are kept
    persistent whenever any are used, and ``prefetch_factor`` is passed only with
    workers and a positive ``args.dataloader_prefetch_factor``.
    """
    use_workers = args.num_workers > 0
    kwargs = {
        "pin_memory": device.type == "cuda" and use_workers,
        "persistent_workers": use_workers,
    }
    if use_workers and args.dataloader_prefetch_factor > 0:
        kwargs["prefetch_factor"] = int(args.dataloader_prefetch_factor)
    return kwargs
|
|
|
|
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Strip a ``DistributedDataParallel`` wrapper, then a ``torch.compile`` ``_orig_mod`` wrapper."""
    inner = model.module if isinstance(model, DistributedDataParallel) else model
    return getattr(inner, "_orig_mod", inner)
|
|
|
|
def rank_zero(rank: int) -> bool:
    """Return whether ``rank`` identifies the main (rank 0) process."""
    is_main_process = rank == 0
    return is_main_process
|
|
|
|
def save_checkpoint(
    path: Path,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    args: argparse.Namespace,
    tokenizer: BpeTextTokenizer,
    step: int,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
    ema_state: dict[str, torch.Tensor] | None = None,
) -> None:
    """Write a training checkpoint to *path* via a temporary ``<path>.tmp`` file.

    The payload holds:
      * ``model`` and ``optimizer`` state dicts;
      * ``vars(args)`` and ``step``;
      * the model vocab size (``args.effective_vocab_size`` when present and non-zero,
        else ``tokenizer.vocab_size``);
      * ``tokenizer.eos_id``.

    With ``ema_state``, an ``ema_model`` state dict is added. EMA tensors are cast to the
    live tensor's dtype and moved to CPU; keys the EMA does not track use the live value
    (detached, on CPU). ``scheduler`` state is included when given. Parent directories are
    created, and the temp file is moved over *path* with ``Path.replace``, presumably so an
    interrupted save never leaves a truncated checkpoint at *path*.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    model_state = model.state_dict()
    payload = {
        "model": model_state,
        "optimizer": optimizer.state_dict(),
        "args": vars(args),
        "step": step,
        "vocab_size": int(getattr(args, "effective_vocab_size", 0) or tokenizer.vocab_size),
        "eos_id": tokenizer.eos_id,
    }
    if ema_state is not None:
        merged = {}
        for key, live in model_state.items():
            if key in ema_state:
                merged[key] = ema_state[key].to(dtype=live.dtype, device="cpu")
            else:
                merged[key] = live.detach().cpu()
        payload["ema_model"] = merged
    if scheduler is not None:
        payload["scheduler"] = scheduler.state_dict()
    staging = path.with_suffix(path.suffix + ".tmp")
    torch.save(payload, staging)
    staging.replace(path)
|
|
|
|
@torch.no_grad()
def run_demo(args, model, tokenizer, batch, device) -> None:
    """Print a fill-in-the-blank decoding demo for the first sample of *batch*.

    For each ratio in the comma-separated ``args.demo_mask_ratios``, the sample is corrupted
    with ``fill_blank_init`` and decoded with ``soft_residual_decode``. The function then
    prints:
      * the target text;
      * the initial and final argmax text (each truncated to 500 characters);
      * token accuracy on the corrupted positions.

    The model runs in eval mode, and its previous train/eval mode is restored afterwards,
    even if decoding raises.
    """
    # The model may be built with a larger vocabulary than the tokenizer
    # (args.vocab_size_override, recorded as args.effective_vocab_size). The initial
    # state must match the model's width, using the same rule as save_checkpoint.
    state_vocab_size = int(getattr(args, "effective_vocab_size", 0) or tokenizer.vocab_size)
    was_training = model.training
    model.eval()
    try:
        ids = batch["ids"][:1].to(device)
        attn_mask = batch["attn_mask"][:1].to(device)
        target = tokenizer.decode(ids[0].tolist())
        print(f"[demo] target: {target[:500]}", flush=True)
        ratios = [float(x) for x in args.demo_mask_ratios.split(",") if x.strip()]
        for ratio in ratios:
            init, mask = fill_blank_init(
                ids,
                state_vocab_size,
                args.target_prob,
                ratio,
                args.demo_fill_t,
                args.eps,
                noise_mode=args.noise_init,
                noise_sigma=args.noise_sigma,
            )
            final = soft_residual_decode(
                model,
                init,
                attn_mask,
                args.infer_steps,
                args.decode_damping,
                args.max_gamma,
                args.eps,
                solver=args.decode_solver,
            )
            # NOTE(review): with a padded vocabulary, argmax may pick ids beyond the
            # tokenizer's range; confirm tokenizer.decode tolerates them.
            init_ids = init.argmax(-1)
            final_ids = final.argmax(-1)
            n_corrupt = mask.float().sum().clamp_min(1)
            acc_init = float(((init_ids == ids) & mask).float().sum() / n_corrupt)
            acc_final = float(((final_ids == ids) & mask).float().sum() / n_corrupt)
            print(f"[demo mask_ratio={ratio:.2f}] init : {tokenizer.decode(init_ids[0].tolist())[:500]}", flush=True)
            print(f"[demo mask_ratio={ratio:.2f}] final: {tokenizer.decode(final_ids[0].tolist())[:500]}", flush=True)
            print(
                f"[demo mask_ratio={ratio:.2f}] corrupt_acc_init={acc_init:.4f} corrupt_acc_final={acc_final:.4f}",
                flush=True,
            )
    finally:
        model.train(was_training)
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| t_bin_edges = parse_t_bin_edges(args.log_t_bin_edges) |
| world_size = int(os.environ.get("WORLD_SIZE", "1")) |
| rank = int(os.environ.get("RANK", "0")) |
| local_rank = int(os.environ.get("LOCAL_RANK", "0")) |
| ddp = world_size > 1 |
| if ddp: |
| if torch.cuda.is_available(): |
| torch.cuda.set_device(local_rank) |
| dist.init_process_group(backend="nccl") |
| if args.global_batch_size > 0: |
| args.grad_accum = max(1, math.ceil(args.global_batch_size / max(1, args.batch_size * world_size))) |
| args.resolved_detokenizer = infer_detokenizer_name(raw_path=args.data_path, explicit=args.detokenizer) |
| seed_all(args.seed + rank) |
| device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") |
| if device.type == "cuda": |
| torch.backends.cuda.matmul.allow_tf32 = bool(args.allow_tf32) |
| torch.backends.cudnn.allow_tf32 = bool(args.allow_tf32) |
|
|
| def ddp_barrier() -> None: |
| if not ddp: |
| return |
| if device.type == "cuda": |
| dist.barrier(device_ids=[local_rank]) |
| else: |
| dist.barrier() |
|
|
| tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) |
| model_vocab_size = int(args.vocab_size_override) if int(args.vocab_size_override) > 0 else int(tokenizer.vocab_size) |
| args.effective_vocab_size = model_vocab_size |
| categorical_wrong_corpus_unigram_probs = load_corpus_unigram_probs( |
| args.categorical_wrong_corpus_unigram_path, |
| model_vocab_size, |
| args.categorical_wrong_corpus_unigram_alpha, |
| device, |
| ) |
| if rank_zero(rank) and categorical_wrong_corpus_unigram_probs is not None: |
| top_probs, top_ids = torch.topk(categorical_wrong_corpus_unigram_probs.detach().cpu(), k=min(10, model_vocab_size)) |
| top_summary = ",".join(f"{int(i)}:{float(p):.3e}" for i, p in zip(top_ids, top_probs)) |
| print( |
| "loaded_categorical_wrong_corpus_unigram=" |
| f"{args.categorical_wrong_corpus_unigram_path} " |
| f"alpha={args.categorical_wrong_corpus_unigram_alpha} top={top_summary}", |
| flush=True, |
| ) |
| shuffle_epoch_sampler = None |
| if args.elf_conditional_hf: |
| if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle or args.streaming or args.record_pad_truncate: |
| raise ValueError("--elf_conditional_hf is mutually exclusive with text streaming/packing modes") |
| if args.conditional_pad_token == "pad": |
| conditional_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else tokenizer.eos_id |
| loss_on_pad = False |
| else: |
| conditional_pad_id = tokenizer.eos_id |
| loss_on_pad = True |
| dataset = load_elf_conditional_dataset(args.data_path, max_records=args.max_records) |
| sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| sampler=sampler, |
| collate_fn=ELFConditionalCollator( |
| conditional_pad_id, |
| max_len=args.max_len, |
| max_input_len=args.max_input_len, |
| loss_on_pad=loss_on_pad, |
| ), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=True, |
| ) |
| shuffle_epoch_sampler = sampler |
| dataset_size = ( |
| f"elf_conditional_hf:{len(dataset)}" |
| f":max_input_len={args.max_input_len}:pad={conditional_pad_id}:loss_on_pad={int(loss_on_pad)}" |
| ) |
| elif args.record_pad_truncate: |
| if args.wrap or args.owt_cached_chunks or args.online_chunk_shuffle: |
| raise ValueError("--record_pad_truncate is mutually exclusive with --wrap/--owt_cached_chunks/--online_chunk_shuffle") |
| if args.record_pad_token == "pad": |
| record_pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else tokenizer.eos_id |
| else: |
| record_pad_id = tokenizer.eos_id |
| dataset = RecordPadTruncateTextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records_per_epoch=args.max_records, |
| detokenizer=args.resolved_detokenizer, |
| add_eos=args.record_add_eos, |
| add_special_tokens=args.record_add_special_tokens, |
| shuffle_buffer_size=args.record_shuffle_buffer, |
| seed=args.seed, |
| epoch=0, |
| ) |
| shuffle_epoch_sampler = dataset |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| collate_fn=FixedPadCollator(record_pad_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=True, |
| ) |
| dataset_size = ( |
| "record_pad_truncate" |
| f":pad={record_pad_id}:add_eos={int(args.record_add_eos)}" |
| f":add_special={int(args.record_add_special_tokens)}" |
| f":shuffle_buffer={args.record_shuffle_buffer}" |
| ) |
| elif args.owt_cached_chunks: |
| if not args.wrap or args.wrap_mode != "stream": |
| raise ValueError("--owt_cached_chunks requires --wrap --wrap_mode stream") |
| if not args.owt_chunk_cache_dir: |
| raise ValueError("--owt_cached_chunks requires --owt_chunk_cache_dir") |
| if args.openwebtext_split not in {"train_minus_100k", "all"}: |
| raise ValueError("--owt_cached_chunks is intended for OWT train/all splits") |
| if ddp and not rank_zero(rank): |
| ddp_barrier() |
| dataset = CachedWrappedTextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| cache_dir=args.owt_chunk_cache_dir, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records=args.max_records, |
| detokenizer=args.resolved_detokenizer, |
| rebuild=args.owt_chunk_cache_rebuild, |
| build_cache=rank_zero(rank), |
| write_batch_size=args.owt_chunk_cache_write_batch, |
| ) |
| if ddp and rank_zero(rank): |
| ddp_barrier() |
| if args.owt_exact_repeat_per_chunk > 0: |
| sampler = ExactRepeatDistributedSampler( |
| len(dataset), |
| args.owt_exact_repeat_per_chunk, |
| seed=args.seed, |
| rank=rank, |
| world_size=world_size, |
| ) |
| shuffle_epoch_sampler = None |
| else: |
| sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) |
| shuffle_epoch_sampler = sampler |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| sampler=sampler, |
| collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=False, |
| ) |
| dataset_size = f"owt_cached_chunks:{len(dataset)}" |
| elif args.wrap and args.online_chunk_shuffle: |
| if args.wrap_mode != "stream": |
| raise ValueError("--online_chunk_shuffle currently requires --wrap_mode stream") |
| dataset = ShuffledWrappedStreamingTextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records_per_epoch=args.max_records, |
| detokenizer=args.resolved_detokenizer, |
| chunk_buffer_size=args.online_chunk_shuffle_buffer, |
| seed=args.seed, |
| epoch=0, |
| ) |
| shuffle_epoch_sampler = dataset |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=False, |
| ) |
| dataset_size = f"wrapped_stream_online_shuffle:{args.online_chunk_shuffle_buffer}" |
| elif args.wrap: |
| dataset = WrappedStreamingTextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records_per_epoch=args.max_records, |
| detokenizer=args.resolved_detokenizer, |
| wrap_mode=args.wrap_mode, |
| wrap_record_buffer_size=args.wrap_record_buffer_size, |
| ) |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=False, |
| ) |
| dataset_size = f"wrapped_{args.wrap_mode}" |
| elif args.streaming: |
| dataset = StreamingTextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records_per_epoch=args.max_records, |
| stride=args.stride, |
| detokenizer=args.resolved_detokenizer, |
| ) |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=False, |
| ) |
| dataset_size = "streaming" |
| else: |
| dataset = TextSequenceDataset( |
| args.data_path, |
| tokenizer, |
| max_len=args.max_len, |
| text_column=args.text_column, |
| txt_record_mode=args.txt_record_mode, |
| openwebtext_split=args.openwebtext_split, |
| max_records=args.max_records, |
| stride=args.stride, |
| detokenizer=args.resolved_detokenizer, |
| ) |
| sampler = DistributedSampler(dataset, shuffle=True, seed=args.seed) if ddp else RandomSampler(dataset) |
| loader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| sampler=sampler, |
| collate_fn=EosPadCollator(tokenizer.eos_id, max_len=args.max_len), |
| num_workers=args.num_workers, |
| **dataloader_perf_kwargs(args, device), |
| drop_last=False, |
| ) |
| dataset_size = len(dataset) |
| model = EndpointPredictor( |
| vocab_size=model_vocab_size, |
| max_len=args.max_len, |
| d_model=args.d_model, |
| n_heads=args.n_heads, |
| n_layers=args.n_layers, |
| dim_ff=args.dim_ff, |
| dropout=args.dropout, |
| model_type=args.model_type, |
| cond_dim=args.cond_dim, |
| input_format=args.state_format, |
| output_bias=args.output_bias, |
| norm_type=args.norm_type, |
| ).to(device) |
| if args.activation_checkpointing and hasattr(model, "set_activation_checkpointing"): |
| model.set_activation_checkpointing( |
| True, |
| interval=args.activation_checkpoint_interval, |
| scope=args.activation_checkpoint_scope, |
| ) |
| if args.torch_compile: |
| model = torch.compile(model, mode=args.compile_mode if args.compile_mode != "default" else None) |
| trainable_model = model |
| if ddp: |
| trainable_model = DistributedDataParallel( |
| model, |
| device_ids=[local_rank], |
| output_device=local_rank, |
| static_graph=bool(args.ddp_static_graph), |
| gradient_as_bucket_view=bool(args.ddp_gradient_as_bucket_view), |
| ) |
| optimizer = configure_adamw_optimizer(unwrap_model(trainable_model), args, device) |
| ema_state = init_ema_state(unwrap_model(trainable_model)) if args.ema_decay > 0 else None |
|
|
| def lr_lambda(step: int) -> float: |
| if args.warmup_steps > 0 and step < args.warmup_steps: |
| return max(1e-12, float(step + 1) / float(args.warmup_steps)) |
| if args.lr_schedule == "constant_warmup": |
| return 1.0 |
| progress = (step - args.warmup_steps) / max(1, args.total_steps - args.warmup_steps) |
| min_lr_ratio = float(args.min_lr) / max(float(args.lr), 1e-12) |
| return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * min(max(progress, 0.0), 1.0)))) |
|
|
| scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
| start_step = 1 |
| if args.init_model_path: |
| if args.resume_path: |
| raise ValueError("--init_model_path is mutually exclusive with --resume_path") |
| ckpt = torch.load(args.init_model_path, map_location=device) |
| init_state = dict(ckpt["model"]) |
| model_state = unwrap_model(trainable_model).state_dict() |
| if "pos_embed" in init_state and "pos_embed" in model_state and init_state["pos_embed"].shape != model_state["pos_embed"].shape: |
| if args.init_pos_embed_mode == "strict": |
| raise ValueError( |
| f"init pos_embed shape {tuple(init_state['pos_embed'].shape)} != " |
| f"model pos_embed shape {tuple(model_state['pos_embed'].shape)}; " |
| "set --init_pos_embed_mode repeat or interpolate to adapt" |
| ) |
| src = init_state["pos_embed"] |
| target_len = int(model_state["pos_embed"].shape[1]) |
| if src.shape[1] > target_len: |
| init_state["pos_embed"] = src[:, :target_len].contiguous() |
| elif args.init_pos_embed_mode == "repeat": |
| reps = math.ceil(target_len / int(src.shape[1])) |
| init_state["pos_embed"] = src.repeat(1, reps, 1)[:, :target_len].contiguous() |
| else: |
| init_state["pos_embed"] = F.interpolate( |
| src.transpose(1, 2), |
| size=target_len, |
| mode="linear", |
| align_corners=True, |
| ).transpose(1, 2).contiguous() |
| if rank_zero(rank): |
| print( |
| f"adapted_init_pos_embed from={tuple(src.shape)} to={tuple(init_state['pos_embed'].shape)} " |
| f"mode={args.init_pos_embed_mode}", |
| flush=True, |
| ) |
| for optional_bias_key in ("output_layer.linear.bias", "out_proj.bias"): |
| if optional_bias_key in init_state and optional_bias_key not in model_state: |
| init_state.pop(optional_bias_key) |
| if rank_zero(rank): |
| print(f"dropped_init_output_bias key={optional_bias_key} because current model has output_bias=False", flush=True) |
| unwrap_model(trainable_model).load_state_dict(init_state) |
| if ema_state is not None and "ema_model" in ckpt: |
| for k, v in ckpt["ema_model"].items(): |
| if k in ema_state: |
| ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype)) |
| if rank_zero(rank): |
| print(f"initialized_model_from={args.init_model_path} start_step={start_step}", flush=True) |
| if args.resume_path: |
| ckpt = torch.load(args.resume_path, map_location=device) |
| unwrap_model(trainable_model).load_state_dict(ckpt["model"]) |
| optimizer.load_state_dict(ckpt["optimizer"]) |
| if "scheduler" in ckpt: |
| scheduler.load_state_dict(ckpt["scheduler"]) |
| else: |
| scheduler.last_epoch = int(ckpt.get("step", 0)) |
| if ema_state is not None and "ema_model" in ckpt: |
| for k, v in ckpt["ema_model"].items(): |
| if k in ema_state: |
| ema_state[k].copy_(v.to(device=ema_state[k].device, dtype=ema_state[k].dtype)) |
| start_step = int(ckpt.get("step", 0)) + 1 |
| if rank_zero(rank): |
| print(f"resumed_from={args.resume_path} start_step={start_step}", flush=True) |
|
|
| save_dir = Path(args.save_dir) |
| if rank_zero(rank): |
| save_dir.mkdir(parents=True, exist_ok=True) |
| (save_dir / "args.json").write_text(json.dumps(vars(args), indent=2), encoding="utf-8") |
| print(json.dumps({ |
| "device": str(device), |
| "rank": rank, |
| "world_size": world_size, |
| "samples": dataset_size, |
| "vocab_size": model_vocab_size, |
| "tokenizer_vocab_size": tokenizer.vocab_size, |
| "save_dir": str(save_dir), |
| "batch_size": args.batch_size, |
| "grad_accum": args.grad_accum, |
| "effective_batch_size": args.batch_size * args.grad_accum * world_size, |
| "global_batch_size": args.global_batch_size, |
| "lr_schedule": args.lr_schedule, |
| "optimizer": args.optimizer, |
| "warmup_steps": args.warmup_steps, |
| "min_lr": args.min_lr, |
| "weight_decay": args.weight_decay, |
| "output_weight_decay": args.output_weight_decay, |
| "adamw_param_groups": args.adamw_param_groups, |
| "adam_beta1": args.adam_beta1, |
| "adam_beta2": args.adam_beta2, |
| "adam_eps": args.adam_eps, |
| "muon_momentum": args.muon_momentum, |
| "muon_ns_steps": args.muon_ns_steps, |
| "muon_update_scale": args.muon_update_scale, |
| "ema_decay": args.ema_decay, |
| "ema_start_step": args.ema_start_step, |
| "model_type": args.model_type, |
| "output_bias": args.output_bias, |
| "norm_type": args.norm_type, |
| "t_sampling_mode": args.t_sampling_mode, |
| "t_sampling_power": args.t_sampling_power, |
| "t_sampling_eps": args.t_sampling_eps, |
| "dual_t": args.dual_t, |
| "corrupt_t_mode": args.corrupt_t_mode, |
| "corrupt_min_t": args.corrupt_min_t, |
| "corrupt_max_t": args.corrupt_max_t, |
| "prefix_block_prob": args.prefix_block_prob, |
| "prefix_block_len": args.prefix_block_len, |
| "mask_ratio_floor_schedule": args.mask_ratio_floor_schedule, |
| "dirichlet_endpoint_mode": args.dirichlet_endpoint_mode, |
| "dirichlet_semantic_t_mode": args.dirichlet_semantic_t_mode, |
| "dirichlet_semantic_t_value": args.dirichlet_semantic_t_value, |
| "dirichlet_semantic_t_curve": args.dirichlet_semantic_t_curve, |
| "dirichlet_semantic_t_power": args.dirichlet_semantic_t_power, |
| "endpoint_sequence_random_prob_alpha": args.endpoint_sequence_random_prob_alpha, |
| "categorical_wrong_from_full_vocab": args.categorical_wrong_from_full_vocab, |
| "categorical_wrong_from_batch_valid_tokens": args.categorical_wrong_from_batch_valid_tokens, |
| "categorical_wrong_basin_token_ids": args.categorical_wrong_basin_token_ids, |
| "categorical_wrong_basin_prob": args.categorical_wrong_basin_prob, |
| "categorical_wrong_unigram_prob": args.categorical_wrong_unigram_prob, |
| "categorical_wrong_uniform_prob": args.categorical_wrong_uniform_prob, |
| "categorical_wrong_corpus_unigram_path": args.categorical_wrong_corpus_unigram_path, |
| "categorical_wrong_corpus_unigram_alpha": args.categorical_wrong_corpus_unigram_alpha, |
| "categorical_wrong_basin_shared_prob": args.categorical_wrong_basin_shared_prob, |
| "categorical_wrong_unigram_shared_prob": args.categorical_wrong_unigram_shared_prob, |
| "mask_mixture_original_prob": args.mask_mixture_original_prob, |
| "mask_mixture_lowk_prob": args.mask_mixture_lowk_prob, |
| "mask_mixture_lowcorrupt_prob": args.mask_mixture_lowcorrupt_prob, |
| "mask_mixture_block_prob": args.mask_mixture_block_prob, |
| "mask_mixture_all_prob": args.mask_mixture_all_prob, |
| "mask_mixture_lowk_clean_tokens": args.mask_mixture_lowk_clean_tokens, |
| "mask_mixture_lowcorrupt_tokens": args.mask_mixture_lowcorrupt_tokens, |
| "mask_mixture_block_tokens": args.mask_mixture_block_tokens, |
| "simplex_bridge_sampler": args.simplex_bridge_sampler, |
| "logistic_normal_sigma_min": args.logistic_normal_sigma_min, |
| "logistic_normal_sigma_max": args.logistic_normal_sigma_max, |
| "logistic_normal_tau_min": args.logistic_normal_tau_min, |
| "logistic_normal_tau_max": args.logistic_normal_tau_max, |
| "torch_compile": args.torch_compile, |
| "compile_mode": args.compile_mode, |
| "state_format": args.state_format, |
| "target_loss": args.target_loss, |
| "meanflow_weight": args.meanflow_weight, |
| "rollout_train_prob": args.rollout_train_prob, |
| "rollout_train_steps": args.rollout_train_steps, |
| "rollout_train_infer_steps": args.rollout_train_infer_steps, |
| "rollout_train_temp": args.rollout_train_temp, |
| "rollout_train_max_gamma": args.rollout_train_max_gamma, |
| "rollout_train_corrupt_only": args.rollout_train_corrupt_only, |
| "rollout_train_samplewise": args.rollout_train_samplewise, |
| "rollout_train_compute_always": args.rollout_train_compute_always, |
| "bridge_noise_init": args.bridge_noise_init, |
| "noise_sigma": args.noise_sigma, |
| "allow_tf32": args.allow_tf32, |
| "activation_checkpointing": args.activation_checkpointing, |
| "activation_checkpoint_interval": args.activation_checkpoint_interval, |
| "activation_checkpoint_scope": args.activation_checkpoint_scope, |
| "ddp_static_graph": args.ddp_static_graph, |
| "ddp_gradient_as_bucket_view": args.ddp_gradient_as_bucket_view, |
| "blocking_data_transfer": args.blocking_data_transfer, |
| "dataloader_prefetch_factor": args.dataloader_prefetch_factor, |
| "full_train_stats": args.full_train_stats, |
| "record_pad_truncate": args.record_pad_truncate, |
| "record_add_eos": args.record_add_eos, |
| "record_add_special_tokens": args.record_add_special_tokens, |
| "record_pad_token": args.record_pad_token, |
| "record_shuffle_buffer": args.record_shuffle_buffer, |
| "wrap": args.wrap, |
| "wrap_mode": args.wrap_mode, |
| "wrap_record_buffer_size": args.wrap_record_buffer_size, |
| "owt_cached_chunks": args.owt_cached_chunks, |
| "owt_chunk_cache_dir": args.owt_chunk_cache_dir, |
| "owt_chunk_cache_rebuild": args.owt_chunk_cache_rebuild, |
| "owt_chunk_cache_write_batch": args.owt_chunk_cache_write_batch, |
| "owt_exact_repeat_per_chunk": args.owt_exact_repeat_per_chunk, |
| "online_chunk_shuffle": args.online_chunk_shuffle, |
| "online_chunk_shuffle_buffer": args.online_chunk_shuffle_buffer, |
| "openwebtext_split": args.openwebtext_split, |
| "detokenizer": args.detokenizer, |
| "resolved_detokenizer": args.resolved_detokenizer, |
| "num_workers": args.num_workers, |
| "latest_every": args.latest_every, |
| "resume_path": args.resume_path, |
| }, indent=2), flush=True) |
|
|
| use_amp = args.bf16 and device.type == "cuda" |
| data_epoch = 0 |
| data_iter = iter(loader) |
| running = [] |
| start = time.time() |
| trainable_model.train() |
| optimizer.zero_grad(set_to_none=True) |
| for step in range(start_step, args.total_steps + 1): |
| last_batch = None |
| for micro_step in range(args.grad_accum): |
| try: |
| batch = next(data_iter) |
| except StopIteration: |
| data_epoch += 1 |
| if shuffle_epoch_sampler is not None and hasattr(shuffle_epoch_sampler, "set_epoch"): |
| shuffle_epoch_sampler.set_epoch(data_epoch) |
| data_iter = iter(loader) |
| batch = next(data_iter) |
| last_batch = batch |
| non_blocking = device.type == "cuda" and not args.blocking_data_transfer |
| ids = batch["ids"].to(device, non_blocking=non_blocking) |
| attn_mask = batch["attn_mask"].to(device, non_blocking=non_blocking) |
| loss_mask = batch.get("loss_mask") |
| cond_seq_mask = batch.get("cond_seq_mask") |
| if loss_mask is not None: |
| loss_mask = loss_mask.to(device, non_blocking=non_blocking) |
| if cond_seq_mask is not None: |
| cond_seq_mask = cond_seq_mask.to(device, non_blocking=non_blocking) |
| bridge_input_mask = loss_mask if loss_mask is not None else attn_mask |
| bridge_attn_mask, force_corrupt_mask = maybe_make_prefix_block_masks(args, bridge_input_mask) |
| model_t = sample_time( |
| ids.size(0), |
| args.min_t, |
| args.max_t, |
| device, |
| mode=args.t_sampling_mode, |
| power=args.t_sampling_power, |
| eps=args.t_sampling_eps, |
| ) |
| bridge = make_bridge( |
| args, |
| ids, |
| bridge_attn_mask, |
| model_vocab_size, |
| force_t=resolve_bridge_force_t(args, model_t), |
| force_corrupt_mask=force_corrupt_mask, |
| categorical_wrong_corpus_unigram_probs=categorical_wrong_corpus_unigram_probs, |
| ) |
| if loss_mask is not None: |
| model_attn_mask = attn_mask |
| if cond_seq_mask is not None and float(args.label_drop_prob) > 0.0: |
| drop = torch.rand(ids.size(0), device=device) < float(args.label_drop_prob) |
| model_attn_mask = attn_mask & ~(cond_seq_mask & drop.view(-1, 1)) |
| bridge.attn_mask = model_attn_mask |
| loss_t_weight = sample_loss_t_weight(args, model_t) |
| sync_grad = micro_step == args.grad_accum - 1 |
| sync_context = ( |
| trainable_model.no_sync() |
| if ddp |
| and isinstance(trainable_model, DistributedDataParallel) |
| and not args.ddp_static_graph |
| and not sync_grad |
| else nullcontext() |
| ) |
| with sync_context: |
| grad_enabled_before_rollout = torch.is_grad_enabled() |
| loss_state, rollout_train_applied, rollout_train_sample_mask = maybe_make_rollout_train_state( |
| args, |
| trainable_model, |
| bridge.state, |
| model_t, |
| bridge.attn_mask, |
| bridge.corrupt_mask, |
| ) |
| grad_enabled_after_rollout = torch.is_grad_enabled() |
| with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16): |
| logits = trainable_model(loss_state, model_t, bridge.attn_mask) |
| logits_requires_grad = bool(logits.requires_grad) |
| loss_logits = logits |
| if args.prior_center_loss_beta > 0: |
| prior_state = make_prior_center_state( |
| args, |
| ids.size(0), |
| ids.size(1), |
| model_vocab_size, |
| device, |
| ) |
| with torch.no_grad(): |
| prior_logits = trainable_model(prior_state, model_t, bridge.attn_mask) |
| loss_logits = logits.float() - float(args.prior_center_loss_beta) * prior_logits.float() |
| if args.target_loss == "soft_ce": |
| if args.loss_t_weight_mode == "none": |
| raw_loss = masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask) |
| else: |
| raw_loss = weighted_masked_soft_ce(loss_logits, bridge.target_probs, bridge.corrupt_mask, loss_t_weight) |
| else: |
| if args.loss_t_weight_mode == "none": |
| raw_loss = masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask) |
| else: |
| raw_loss = weighted_masked_ce(loss_logits, bridge.ids, bridge.corrupt_mask, loss_t_weight) |
| raw_loss_requires_grad = bool(raw_loss.requires_grad) |
| if args.meanflow_weight > 0: |
| if args.state_format == "prob": |
| current_probs = loss_state |
| else: |
| current_probs = loss_state.float().exp() |
| meanflow_loss = masked_meanflow_l2( |
| logits, |
| current_probs=current_probs, |
| target_probs=bridge.target_probs, |
| mask=bridge.corrupt_mask, |
| ) |
| total_loss = raw_loss + float(args.meanflow_weight) * meanflow_loss |
| else: |
| meanflow_loss = raw_loss.new_zeros(()) |
| total_loss = raw_loss |
| if float(args.rollout_train_prob) > 0.0: |
| total_loss = total_loss + zero_param_anchor(unwrap_model(trainable_model), total_loss) |
| loss = total_loss / max(1, args.grad_accum) |
| loss.backward() |
|
|
| last_micro_step = micro_step == args.grad_accum - 1 |
| heavy_stats = args.full_train_stats or (step % args.log_every == 0 and last_micro_step) |
| collect_stats = rank_zero(rank) and ( |
| args.full_train_stats |
| or args.log_train_stats_each_step |
| or heavy_stats |
| ) |
| if collect_stats: |
| |
| |
def scalar(x: torch.Tensor) -> float:
    """Return ``x`` as a plain Python float.

    ``x`` is detached from the autograd graph, cast to float32 and copied to
    host memory before conversion. ``float()`` on a tensor requires a single
    element, so ``x`` is expected to be a scalar/one-element tensor.
    """
    host_value = x.detach().float().cpu()
    return float(host_value)
|
|
| batch_weight = float(ids.size(0)) |
| corrupt_count = bridge.corrupt_mask.float().sum() |
| stats = { |
| "loss": float(total_loss.detach().cpu()), |
| "loss__weight": scalar(corrupt_count), |
| "loss_recon": float(raw_loss.detach().cpu()), |
| "loss_recon__weight": scalar(corrupt_count), |
| "loss_meanflow": float(meanflow_loss.detach().cpu()), |
| "loss_meanflow__weight": scalar(corrupt_count), |
| "mean_model_t": float(model_t.mean().detach().cpu()), |
| "mean_model_t__weight": batch_weight, |
| "mean_corrupt_t": float(bridge.t.mean().detach().cpu()), |
| "mean_corrupt_t__weight": batch_weight, |
| "mean_loss_t_weight": float(loss_t_weight.mean().detach().cpu()), |
| "mean_loss_t_weight__weight": batch_weight, |
| "prior_center_loss_beta": float(args.prior_center_loss_beta), |
| "rollout_train_applied": float(rollout_train_applied), |
| "rollout_train_applied__weight": batch_weight, |
| "grad_enabled_before_rollout": float(grad_enabled_before_rollout), |
| "grad_enabled_after_rollout": float(grad_enabled_after_rollout), |
| "logits_requires_grad": float(logits_requires_grad), |
| "raw_loss_requires_grad": float(raw_loss_requires_grad), |
| } |
| with torch.no_grad(): |
| pred = loss_logits.detach().argmax(dim=-1) |
| attn_mask_f = bridge.attn_mask.float() |
| corrupt_mask_f = bridge.corrupt_mask.float() |
| attn_count = attn_mask_f.sum() |
| correct_all = ((pred == bridge.ids) & bridge.attn_mask).float().sum() |
| correct_corrupt = ((pred == bridge.ids) & bridge.corrupt_mask).float().sum() |
| if scalar(attn_count) > 0.0: |
| stats.update({ |
| "acc_all": scalar(correct_all / attn_count.clamp_min(1.0)), |
| "acc_all__weight": scalar(attn_count), |
| "corrupt_frac": scalar(corrupt_count / attn_count.clamp_min(1.0)), |
| "corrupt_frac__weight": scalar(attn_count), |
| }) |
| if scalar(corrupt_count) > 0.0: |
| stats.update({ |
| "acc_corrupt": scalar(correct_corrupt / corrupt_count.clamp_min(1.0)), |
| "acc_corrupt__weight": scalar(corrupt_count), |
| "loss_corrupt": float(raw_loss.detach().cpu()), |
| "loss_corrupt__weight": scalar(corrupt_count), |
| "wrong_frac": scalar(bridge.wrong_mask.float().sum() / corrupt_count.clamp_min(1.0)), |
| "wrong_frac__weight": scalar(corrupt_count), |
| }) |
| init_pred = loss_state.detach().argmax(dim=-1) |
| init_correct = ((init_pred == bridge.ids) & bridge.corrupt_mask).float().sum() |
| stats.update({ |
| "init_acc_corrupt": scalar(init_correct / corrupt_count.clamp_min(1.0)), |
| "init_acc_corrupt__weight": scalar(corrupt_count), |
| }) |
| if args.log_t_bin_stats: |
| total_corrupt = corrupt_count.clamp_min(1.0) |
| sample_t = bridge.t.detach().float().view(-1) |
| for lo, hi in zip(t_bin_edges[:-1], t_bin_edges[1:]): |
| if hi >= t_bin_edges[-1]: |
| sample_mask = (sample_t >= lo) & (sample_t <= hi) |
| else: |
| sample_mask = (sample_t >= lo) & (sample_t < hi) |
| pos_mask = bridge.corrupt_mask & sample_mask[:, None] |
| denom = pos_mask.float().sum() |
| if scalar(denom) <= 0.0: |
| continue |
| tag = f"{lo:.1f}_{min(hi, 1.0):.1f}".replace(".", "p") |
| correct_bin = ((pred == bridge.ids) & pos_mask).float().sum() |
| stats[f"acc_corrupt_t_{tag}"] = scalar(correct_bin / denom.clamp_min(1.0)) |
| stats[f"acc_corrupt_t_{tag}__weight"] = scalar(denom) |
| stats[f"corrupt_frac_t_{tag}"] = scalar(denom / total_corrupt) |
| stats[f"corrupt_frac_t_{tag}__weight"] = scalar(corrupt_count) |
| model_for_stats = unwrap_model(trainable_model) |
| out_linear = getattr(getattr(model_for_stats, "output_layer", None), "linear", None) |
| if out_linear is None: |
| out_linear = getattr(model_for_stats, "out_proj", None) |
| if out_linear is not None: |
| out_w = out_linear.weight.detach().float() |
| out_g = out_linear.weight.grad |
| stats["out_w_norm"] = float(out_w.norm().cpu()) |
| stats["out_g_norm"] = float(out_g.detach().float().norm().cpu()) if out_g is not None else 0.0 |
| if heavy_stats: |
| stats.update(summarize_batch( |
| loss_logits.detach(), |
| bridge.ids, |
| bridge.attn_mask, |
| bridge.corrupt_mask, |
| target_probs=bridge.target_probs if args.target_loss == "soft_ce" else None, |
| include_loss=True, |
| )) |
| if args.log_t_bin_stats: |
| stats.update(summarize_t_bin_acc( |
| loss_logits.detach(), |
| bridge.ids, |
| bridge.corrupt_mask, |
| bridge.t.detach(), |
| edges=t_bin_edges, |
| )) |
| stats.update({ |
| "init_gold_top10": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 10).cpu()), |
| "init_gold_top100": float(masked_topk_acc(loss_state.detach(), bridge.ids, bridge.corrupt_mask, 100).cpu()), |
| }) |
| if rollout_train_sample_mask is not None: |
| applied_pos = bridge.corrupt_mask & rollout_train_sample_mask.view(-1, 1) |
| kept_pos = bridge.corrupt_mask & (~rollout_train_sample_mask.view(-1, 1)) |
| stats.update({ |
| "rollout_applied_pos_frac": float((applied_pos.float().sum() / bridge.corrupt_mask.float().sum().clamp_min(1.0)).detach().cpu()), |
| "init_acc_rollout_applied": float(masked_acc(loss_state.detach(), bridge.ids, applied_pos).cpu()), |
| "init_acc_rollout_kept": float(masked_acc(loss_state.detach(), bridge.ids, kept_pos).cpu()), |
| "logit_acc_rollout_applied": float(masked_acc(loss_logits.detach(), bridge.ids, applied_pos).cpu()), |
| "logit_acc_rollout_kept": float(masked_acc(loss_logits.detach(), bridge.ids, kept_pos).cpu()), |
| }) |
| running.append(stats) |
|
|
| if args.grad_clip > 0: |
| torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) |
| optimizer.step() |
| scheduler.step() |
| if ema_state is not None and step >= args.ema_start_step: |
| update_ema_state(ema_state, unwrap_model(trainable_model), args.ema_decay) |
| optimizer.zero_grad(set_to_none=True) |
| if step % args.log_every == 0: |
| if rank_zero(rank): |
| avg = average_stats_window(running) if running else {} |
| elapsed = time.time() - start |
| print(" ".join([ |
| f"step={step}", |
| f"micro_steps={step * args.grad_accum}", |
| f"elapsed={elapsed:.1f}s", |
| f"lr={scheduler.get_last_lr()[0]:.6e}", |
| ] + [f"{k}={v:.4f}" for k, v in avg.items()]), flush=True) |
| running = [] |
| start = time.time() |
| if rank_zero(rank) and args.eval_every > 0 and step % args.eval_every == 0: |
| run_demo(args, unwrap_model(trainable_model), tokenizer, last_batch, device) |
| if rank_zero(rank) and args.save_every > 0 and step % args.save_every == 0: |
| save_checkpoint(save_dir / f"step_{step:07d}.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state) |
| save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state) |
| elif rank_zero(rank) and args.latest_every > 0 and step % args.latest_every == 0: |
| save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, step, scheduler, ema_state) |
|
|
| if rank_zero(rank): |
| save_checkpoint(save_dir / "latest.pt", unwrap_model(trainable_model), optimizer, args, tokenizer, args.total_steps, scheduler, ema_state) |
| if ddp: |
| dist.destroy_process_group() |
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|