#!/usr/bin/env python
"""
Unified training script for foveated VLM (all stages).
Usage:
# Single GPU
python train.py --config configs/stage1_135M.yaml
# Multi-GPU (2xA100)
torchrun --nproc_per_node=2 train.py --config configs/stage1_135M.yaml
# Dry run (verify config, dataloaders, shapes)
python train.py --config configs/stage1_135M.yaml --dry-run
"""
import argparse
import gc
import os
import sys
import time
from contextlib import nullcontext
import torch
import torch.nn as nn
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP
# TF32 for free speedup on Ampere+ GPUs (RTX 3090, A100, RTX 5090, etc.)
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True # auto-tune conv algorithms (DINO patch embed)
torch.backends.cudnn.allow_tf32 = True # TF32 for cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = True # redundant with set_float32_matmul_precision but explicit
# nanochat: expandable segments for GPU memory allocator
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Local imports (flat layout — all modules at repo root).
from model import FoveatedVLM
from data import make_dataloader, make_dynamic_dataloader, create_dpo_webdataset
from collate import collate_dpo
from text_interleave import InterleavedDataLoader
from distributed import (
setup_distributed,
cleanup_distributed,
is_main_process,
get_rank,
reduce_mean,
)
from checkpoint import save_checkpoint, load_latest_checkpoint
from schedule import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup, get_converging_schedule
from logger import TrainingLogger
from attention_viz import compute_attention_entropy, save_attention_maps
# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #
def parse_args():
p = argparse.ArgumentParser(description="fVLM training")
p.add_argument("--config", required=True, help="YAML config path")
p.add_argument("--dry-run", action="store_true",
help="Parse config, build model & dataloader, print shapes, exit")
return p.parse_args()
def load_config(path: str) -> dict:
with open(path) as f:
cfg = yaml.safe_load(f)
return cfg
# --------------------------------------------------------------------------- #
# Build components
# --------------------------------------------------------------------------- #
def build_model(cfg: dict, device: torch.device):
model = FoveatedVLM(
llm_name=cfg["model"]["llm"],
dino_name=cfg["model"]["dino"],
query_dim=cfg["model"].get("query_dim", 384),
visual_scale=cfg["model"].get("visual_scale", 0.14),
lambda_coarse=cfg["model"].get("lambda_coarse", 0.0),
deep_query=cfg["model"].get("deep_query", True),
use_fused_ce=cfg["model"].get("use_fused_ce", False),
)
# Initialise from a previous-stage checkpoint (Stage 2 loads Stage 1, etc.)
init_from = cfg["model"].get("init_from")
if init_from and os.path.exists(init_from):
ckpt = torch.load(init_from, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
if is_main_process():
print(f" Loaded weights from {init_from}")
if cfg["model"].get("gradient_checkpointing", False):
model.enable_gradient_checkpointing()
model = model.to(device)
# channels_last for DINO conv layers (patch embedding) — better tensor core util
if hasattr(model, 'encoder') and hasattr(model.encoder, 'dino'):
model.encoder.dino = model.encoder.dino.to(memory_format=torch.channels_last)
# ---- Freeze parameters based on config ----
dino_module = getattr(model, 'encoder', None)
if dino_module is not None:
dino_module = getattr(dino_module, 'dino', None)
if dino_module is None:
dino_module = getattr(model, 'dino', None)
if cfg["model"].get("freeze_dino", False) and dino_module is not None:
for p in dino_module.parameters():
p.requires_grad = False
if is_main_process():
print(" Frozen: DINO encoder")
if cfg["model"].get("freeze_llm", False):
for p in model.llm.parameters():
p.requires_grad = False
if is_main_process():
print(" Frozen: LLM backbone")
return model
def _get_tokenizer(cfg: dict):
"""Lazy-load the tokenizer for on-the-fly tokenization of raw captions."""
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(cfg["model"]["llm"])
if tok.pad_token is None:
tok.pad_token = tok.eos_token
return tok
def build_train_loader(cfg: dict, epoch: int = 0):
"""Build the training dataloader (vision + optional text interleave)."""
stage = cfg.get("stage", 1)
tokenizer = _get_tokenizer(cfg)
use_dynamic = cfg["training"].get("dynamic_batching", False)
if use_dynamic:
vision_loader = make_dynamic_dataloader(
shard_pattern=cfg["data"]["train_shards"],
max_total_frames=cfg["training"].get("max_total_frames", 512),
max_batch_size=cfg["training"].get("max_batch_size", 64),
max_frames=cfg["data"].get("max_frames", 64),
min_frames=cfg["data"].get("min_frames", 0),
shuffle=True,
seed=cfg["training"].get("seed", 42),
epoch=epoch,
num_workers=cfg["data"].get("num_workers", 12),
prefetch_factor=cfg["data"].get("prefetch_factor", 8),
tokenizer=tokenizer,
stage=stage,
replicate_image_frames=cfg["data"].get("replicate_image_frames", 1),
)
else:
vision_loader = make_dataloader(
shard_pattern=cfg["data"]["train_shards"],
batch_size=cfg["training"]["batch_size"],
max_frames=cfg["data"].get("max_frames", 64),
min_frames=cfg["data"].get("min_frames", 0),
shuffle=True,
seed=cfg["training"].get("seed", 42),
epoch=epoch,
num_workers=cfg["data"].get("num_workers", 12),
prefetch_factor=cfg["data"].get("prefetch_factor", 8),
tokenizer=tokenizer,
stage=stage,
replicate_image_frames=cfg["data"].get("replicate_image_frames", 1),
)
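    # Optionally interleave text-only batches with the vision batches. The
    # reading of text_ratio below is an assumption about InterleavedDataLoader:
    # roughly the fraction of batches drawn from the text loader, so
    # text_ratio=0.25 would make about 1 in 4 batches pure text.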
text_ratio = cfg["data"].get("text_ratio", 0.0)
if text_ratio > 0 and cfg["data"].get("text_shards"):
text_loader = make_dataloader(
shard_pattern=cfg["data"]["text_shards"],
batch_size=cfg["training"]["batch_size"],
max_frames=1,
shuffle=True,
seed=cfg["training"].get("seed", 42),
epoch=epoch,
num_workers=max(1, cfg["data"].get("num_workers", 12) // 2),
prefetch_factor=cfg["data"].get("prefetch_factor", 8),
tokenizer=tokenizer,
stage=stage,
)
return InterleavedDataLoader(
vision_loader=vision_loader,
text_loader=text_loader,
text_ratio=text_ratio,
seed=cfg["training"].get("seed", 42) + epoch,
)
return vision_loader
def build_dpo_train_loader(cfg: dict, epoch: int = 0):
"""Build the training dataloader for DPO (preference) data."""
tokenizer = _get_tokenizer(cfg)
dataset = create_dpo_webdataset(
shard_pattern=cfg["data"]["train_shards"],
tokenizer=tokenizer,
max_frames=cfg["data"].get("max_frames", 64),
shuffle=True,
seed=cfg["training"].get("seed", 42),
epoch=epoch,
num_workers=cfg["data"].get("num_workers", 2),
replicate_image_frames=cfg["data"].get("replicate_image_frames", 1),
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=cfg["training"]["batch_size"],
num_workers=cfg["data"].get("num_workers", 2),
collate_fn=collate_dpo,
pin_memory=True,
prefetch_factor=cfg["data"].get("prefetch_factor", 2),
persistent_workers=cfg["data"].get("num_workers", 2) > 0,
)
return loader
def build_reference_model(cfg: dict, device: torch.device):
"""
Build a frozen reference model for DPO training.
The reference model is a copy of the policy model loaded from the same
init_from checkpoint (the Stage 2 best). All parameters are frozen
and the model is set to eval mode.
"""
ref_model = build_model(cfg, device)
ref_model.eval()
for p in ref_model.parameters():
p.requires_grad = False
return ref_model
def compute_dpo_loss(
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
beta: float = 0.1,
) -> dict:
"""
Compute the DPO loss and reward accuracy.
    DPO loss = -log σ(β * [(policy_chosen_logps - ref_chosen_logps)
                           - (policy_rejected_logps - ref_rejected_logps)])
Parameters
----------
policy_chosen_logps : [B] log-probs from policy on chosen
policy_rejected_logps : [B] log-probs from policy on rejected
ref_chosen_logps : [B] log-probs from reference on chosen
ref_rejected_logps : [B] log-probs from reference on rejected
beta : float DPO temperature
Returns
-------
dict with keys:
loss : scalar DPO loss
reward_accuracy : float, fraction where chosen is preferred
chosen_reward : [B] implicit reward for chosen
rejected_reward : [B] implicit reward for rejected
"""
# Implicit rewards: β * (log π_policy - log π_ref)
chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps)
rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps)
# DPO loss: -log σ(r_chosen - r_rejected)
logits = chosen_reward - rejected_reward
loss = -torch.nn.functional.logsigmoid(logits).mean()
# Reward accuracy: fraction where chosen is preferred over rejected
reward_accuracy = (logits > 0).float().mean().item()
return {
"loss": loss,
"reward_accuracy": reward_accuracy,
"chosen_reward": chosen_reward,
"rejected_reward": rejected_reward,
}
def build_val_loader(cfg: dict):
val_shards = cfg["data"].get("val_shards")
if not val_shards:
return None
stage = cfg.get("stage", 1)
tokenizer = _get_tokenizer(cfg)
return make_dataloader(
shard_pattern=val_shards,
batch_size=cfg["training"]["batch_size"],
max_frames=cfg["data"].get("max_frames", 64),
shuffle=False,
num_workers=0, # load in main process — eval is small, avoids RAM spike
tokenizer=tokenizer,
stage=stage,
)
# --------------------------------------------------------------------------- #
# Evaluation
# --------------------------------------------------------------------------- #
@torch.no_grad()
def evaluate(model, val_loader, device, amp_dtype, use_amp, cfg,
save_attn_dir=None, step=0):
"""Run validation and return dict of average losses + attention entropy."""
model.eval()
raw_model = model.module if hasattr(model, "module") else model
is_foveated = hasattr(raw_model, "encoder")
total_loss = 0.0
total_fine = 0.0
total_coarse = 0.0
total_entropy = 0.0
entropy_count = 0
count = 0
max_samples = cfg.get("eval", {}).get("max_samples", 1000)
eval_mode = "coarse_only" if cfg["model"].get("coarse_only", False) else "coarse_fine"
attn_samples_saved = 0
max_attn_saves = 10 # save attention maps for first 10 eval batches
for batch in val_loader:
if count >= max_samples:
break
batch = {
k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
outputs = model(
frames=batch["frames"],
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
loss_mask=batch["loss_mask"],
frame_mask=batch.get("frame_mask"),
mode=eval_mode,
)
bs = batch["frames"].shape[0]
total_loss += outputs["loss"].item() * bs
total_fine += outputs.get("fine_loss", outputs["loss"]).item() * bs
total_coarse += outputs.get("coarse_loss", torch.tensor(0.0)).item() * bs
count += bs
# Attention entropy (foveated model only, sample periodically)
if is_foveated and entropy_count < 50:
try:
frames = batch["frames"]
B, T = frames.shape[:2]
kv_cache, _, mask_flat = raw_model._encode_all_frames(frames)
q_static = raw_model.q_static.expand(B, -1)
# Compute entropy on first frame
frame0_kv = raw_model._extract_frame_kv(kv_cache, mask_flat, B, T, 0)
_, attn_w = raw_model.encoder.query_attend(
q_static, frame0_kv, return_attention=True,
)
total_entropy += compute_attention_entropy(attn_w) * bs
entropy_count += bs
# Save attention maps for a few samples
if save_attn_dir and attn_samples_saved < max_attn_saves:
for t in range(min(T, 4)):
if t > 0:
frame_kv = raw_model._extract_frame_kv(kv_cache, mask_flat, B, T, t)
_, attn_w = raw_model.encoder.query_attend(
q_static, frame_kv, return_attention=True,
)
save_attention_maps(
attn_w, save_attn_dir, step,
sample_idx=0, frame_idx=t,
prefix=f"attn_s{attn_samples_saved:03d}",
)
attn_samples_saved += 1
except Exception:
pass # don't break eval if attention extraction fails
avg_loss = reduce_mean(torch.tensor(total_loss / max(count, 1), device=device)).item()
avg_fine = reduce_mean(torch.tensor(total_fine / max(count, 1), device=device)).item()
avg_coarse = reduce_mean(torch.tensor(total_coarse / max(count, 1), device=device)).item()
avg_entropy = total_entropy / max(entropy_count, 1) if entropy_count > 0 else 0.0
return {
"val_loss": avg_loss,
"val_fine_loss": avg_fine,
"val_coarse_loss": avg_coarse,
"attention_entropy": avg_entropy,
}
# --------------------------------------------------------------------------- #
# Throughput: maximize batch size to fill GPU memory
# --------------------------------------------------------------------------- #
def _maximize_batch_size(cfg: dict, device: torch.device):
"""
Increase batch_size and decrease grad_accum to keep the same effective
batch while processing more samples per forward pass. Larger micro-batches
improve GPU utilization by giving the GPU more parallel work.
The effective batch (batch_size * grad_accum * world_size) stays constant
so learning dynamics are unchanged.
"""
bs = cfg["training"]["batch_size"]
ga = cfg["training"]["grad_accum"]
effective = bs * ga
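    # Keep bs * ga invariant: e.g. bs=8, ga=4 (effective 32) can become
    # bs=16, ga=2 (still effective 32 per GPU).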
    # Nothing to tune on CPU-only runs; the per-model caps below assume CUDA.
    if not torch.cuda.is_available():
        return
# Conservative VRAM targets per model size (leave headroom for spikes)
llm_path = cfg["model"].get("llm", "")
if "1.7B" in llm_path or "1.7b" in llm_path:
max_bs = 8 # 1.7B needs gradient checkpointing, limited VRAM
elif "360M" in llm_path or "360m" in llm_path:
max_bs = min(effective, 16) # 360M: ~6 GB model+optim, fits bs=16
else:
# 135M or smaller: model is tiny but video frames dominate VRAM.
# DINO processes ALL frames in the batch at once; with bucketed padding
# a batch of 32 × 64 padded frames = 2048 images → OOM on 32GB.
# bs=16 is a safe 2× improvement over bs=8.
max_bs = min(effective, 16)
if max_bs <= bs:
return # already at or above target
new_ga = max(1, effective // max_bs)
new_bs = effective // new_ga # adjust to keep effective exact
if new_bs > bs:
if is_main_process():
print(f" [THROUGHPUT] Batch size: {bs}×{ga}{new_bs}×{new_ga} "
f"(effective={new_bs * new_ga}, was {effective})")
cfg["training"]["batch_size"] = new_bs
cfg["training"]["grad_accum"] = new_ga
# --------------------------------------------------------------------------- #
# Main training loop
# --------------------------------------------------------------------------- #
def train(cfg: dict, args):
rank, world_size, device = setup_distributed()
# ---- Throughput overrides ----
# KEEP IT SIMPLE: only safe code-level opts (TF32, cuDNN, channels_last).
# DO NOT override batch_size, num_workers, or prefetch_factor here.
# bs=32/16 + high workers caused repeated system OOM crashes.
# C1-C3 ran stable for hours at bs=8, workers=2, 43-44 samp/s.
# ---- DPO mode detection ----
is_dpo = cfg.get("loss", {}).get("type") == "dpo"
dpo_beta = cfg.get("loss", {}).get("beta", 0.1)
if is_main_process():
print(f"=== fVLM Training: Stage {cfg['stage']} ===")
print(f" World size: {world_size}")
print(f" Device: {device}")
print(f" Dtype: {cfg['training'].get('dtype', 'float32')}")
if is_dpo:
print(f" Loss type: DPO (beta={dpo_beta})")
# ---- Model ----
model = build_model(cfg, device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
raw_model = model.module if hasattr(model, "module") else model
# ---- Reference model for DPO (frozen copy from same init checkpoint) ----
ref_model = None
if is_dpo:
if is_main_process():
print(" Loading reference model (frozen) ...")
ref_model = build_reference_model(cfg, device)
if is_main_process():
ref_params = sum(p.numel() for p in ref_model.parameters())
print(f" Reference model: {ref_params:,} params (all frozen)")
# ---- Gradient checkpointing (nanochat: trade compute for memory) ----
if cfg["model"].get("gradient_checkpointing", False):
if hasattr(raw_model, "enable_gradient_checkpointing"):
llm_only = cfg["training"].get("compile_encoder", False)
# use_reentrant=False is required for torch.compile compatibility.
# The reentrant checkpoint implementation causes NaN with compile.
use_compile = cfg["training"].get("compile", False)
use_reentrant = not use_compile # non-reentrant when compiling
raw_model.enable_gradient_checkpointing(
llm_only=llm_only, use_reentrant=use_reentrant,
)
if is_main_process():
mode = "LLM only" if llm_only else "LLM + DINO"
reentrant_str = "reentrant" if use_reentrant else "non-reentrant (compile-safe)"
print(f" Gradient checkpointing: {mode}, {reentrant_str}")
# ---- Optimizer (differential LR) ----
param_groups = raw_model.get_param_groups(
lr_backbone=cfg["training"].get("lr_dino", 1e-5),
lr_connector=cfg["training"].get("lr_connector", 1e-4),
)
# Override LLM LR if specified separately from DINO
llm_lr = cfg["training"].get("lr_llm")
if llm_lr is not None:
for g in param_groups:
if g.get("name") == "llm":
g["lr"] = llm_lr
optimizer = torch.optim.AdamW(
param_groups,
weight_decay=cfg["training"].get("weight_decay", 0.01),
fused=True, # nanochat: fused kernel eliminates Python overhead
)
# ---- Schedule ----
grad_accum = cfg["training"].get("grad_accum", 1)
effective_batch = cfg["training"]["batch_size"] * grad_accum * world_size
total_steps = cfg["training"]["total_samples"] // effective_batch
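    # e.g. total_samples=1_000_000 with effective_batch=64 → 15,625 optimizer steps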
warmup_steps = int(total_steps * cfg["training"].get("warmup_ratio", 0.05))
schedule_type = cfg["training"].get("schedule", "cosine")
if schedule_type == "constant":
scheduler = get_constant_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
)
elif schedule_type == "converging":
# Stage 1: connector 100:1 → 1:1 convergence with backbone
target_lr = cfg["training"].get("target_lr", 3e-5)
scheduler = get_converging_schedule(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
target_lr=target_lr,
)
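        # Per the 100:1 → 1:1 note above, the connector group starts far hotter
        # than the backbone and both are driven toward target_lr, so all
        # parameter groups end training at the same rate (the exact trajectory
        # is up to get_converging_schedule).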
if is_main_process():
print(f" Schedule: converging to target_lr={target_lr} (100:1 → 1:1)")
else:
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
# ---- Mixed precision ----
dtype_str = cfg["training"].get("dtype", "float32")
use_amp = dtype_str in ("bfloat16", "float16")
amp_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}.get(
dtype_str, torch.float32
)
# GradScaler only needed for float16 (bfloat16 doesn't need it)
scaler = torch.amp.GradScaler(enabled=(dtype_str == "float16"))
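    # bfloat16 shares float32's 8-bit exponent range, so gradients rarely
    # underflow; loss scaling only matters for float16's 5-bit exponent.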
# ---- Resume ----
start_step = 0
data_position = 0
ckpt_dir = cfg["checkpoint"]["save_dir"]
if cfg["checkpoint"].get("resume") == "auto":
resume_info = load_latest_checkpoint(
ckpt_dir, model, optimizer, scaler, scheduler,
map_location=str(device),
)
if resume_info:
start_step = resume_info["step"]
data_position = resume_info["data_position"]
# ---- Logger ----
logger = TrainingLogger(
project=cfg.get("wandb", {}).get("project", "foveated-vlm"),
config=cfg,
enabled=is_main_process(),
)
# ---- torch.compile ----
# WARNING: torch.compile + gradient_checkpointing = NaN loss (known incompatibility).
if cfg["training"].get("compile", False) and cfg["model"].get("gradient_checkpointing", False):
if is_main_process():
print(" WARNING: compile=true + gradient_checkpointing=true produces NaN loss!")
print(" Disabling torch.compile. Set compile=false in config to suppress this warning.")
cfg["training"]["compile"] = False
if cfg["training"].get("compile", False) and hasattr(torch, "compile"):
compile_mode = cfg["training"].get("compile_mode", "reduce-overhead")
if is_main_process():
print(f" Compiling model with torch.compile ({compile_mode}) ...")
# Compile individual components to avoid graph breaks at boundaries
fullgraph_encoder = cfg["training"].get("fullgraph_encoder", True)
# DINO encoder: fixed 224x224 inputs → dynamic=False, fullgraph for max optimization
raw_model.encoder = torch.compile(
raw_model.encoder, mode=compile_mode, dynamic=False, fullgraph=fullgraph_encoder,
)
# LLM: variable sequence length → dynamic=True
raw_model.llm = torch.compile(raw_model.llm, mode=compile_mode, dynamic=True)
raw_model.dino_to_llm = torch.compile(raw_model.dino_to_llm, mode=compile_mode)
raw_model.llm_to_query = torch.compile(raw_model.llm_to_query, mode=compile_mode)
elif cfg["training"].get("compile_encoder", False) and hasattr(torch, "compile"):
# Selective compile: DINO encoder only. Safe with gradient checkpointing
# because DINO doesn't use grad_ckpt when llm_only=True.
# DINO has fixed 224×224 inputs → dynamic=False for better optimization.
compile_mode = cfg["training"].get("compile_mode", "reduce-overhead")
if is_main_process():
print(f" Compiling DINO encoder only with torch.compile ({compile_mode}) ...")
raw_model.encoder = torch.compile(raw_model.encoder, mode=compile_mode, dynamic=False)
# ---- Val loader ----
val_loader = build_val_loader(cfg)
# ---- Dry run ----
if args.dry_run:
if is_main_process():
if is_dpo:
loader = build_dpo_train_loader(cfg, epoch=0)
else:
loader = build_train_loader(cfg, epoch=0)
batch = next(iter(loader))
print(f"\n Dry run OK:")
for k, v in batch.items():
shape = v.shape if hasattr(v, "shape") else type(v).__name__
print(f" {k:20s} {shape}")
print(f" total_steps = {total_steps}")
print(f" warmup_steps = {warmup_steps}")
print(f" effective_batch = {effective_batch}")
n_params = sum(p.numel() for p in raw_model.parameters())
n_train = sum(p.numel() for p in raw_model.parameters() if p.requires_grad)
print(f" total_params = {n_params:,}")
print(f" trainable_params = {n_train:,}")
cleanup_distributed()
return
if is_main_process():
n_params = sum(p.numel() for p in raw_model.parameters())
print(f" Parameters: {n_params:,}")
print(f" Total steps: {total_steps}")
print(f" Warmup: {warmup_steps}")
print(f" Eff. batch: {effective_batch}")
print(f" Starting at: step={start_step}, samples={data_position}")
print()
# ---- Train ----
max_grad_norm = cfg["training"].get("max_grad_norm", 1.0)
save_every = cfg["checkpoint"].get("save_every_steps", 1000)
eval_every = cfg.get("eval", {}).get("every_steps", 500)
log_every = 10
train_mode = "coarse_only" if cfg["model"].get("coarse_only", False) else "coarse_fine"
if is_main_process() and train_mode != "coarse_fine":
print(f" Train mode: {train_mode}")
global_step = start_step
samples_seen = data_position
epoch = data_position // max(cfg["training"]["total_samples"], 1)
micro_step = 0
model.train()
optimizer.zero_grad(set_to_none=True) # nanochat: faster than setting to zero
# nanochat: disable GC during training — saves ~500ms per collection
gc.collect()
gc.disable()
t0 = time.time()
# Track DPO metrics across micro-steps for logging
dpo_reward_acc_accum = 0.0
dpo_chosen_reward_accum = 0.0
dpo_rejected_reward_accum = 0.0
dpo_micro_count = 0
while global_step < total_steps:
if is_dpo:
train_loader = build_dpo_train_loader(cfg, epoch=epoch)
else:
train_loader = build_train_loader(cfg, epoch=epoch)
for batch in train_loader:
if global_step >= total_steps:
break
# Move to device
batch = {
k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()
}
# Gradient accumulation: skip DDP sync on non-final micro-steps
is_accum = ((micro_step + 1) % grad_accum != 0)
sync_ctx = model.no_sync() if (world_size > 1 and is_accum) else nullcontext()
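            # e.g. grad_accum=4: micro-steps 1-3 run under no_sync() so gradients
            # accumulate locally; the 4th all-reduces and reaches the optimizer
            # step below.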
with sync_ctx:
try:
if is_dpo:
# ---- DPO forward pass ----
with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
# Policy model forward
policy_out = raw_model.forward_dpo(
frames=batch["frames"],
chosen_input_ids=batch["chosen_input_ids"],
chosen_attention_mask=batch["chosen_attention_mask"],
chosen_loss_mask=batch["chosen_loss_mask"],
rejected_input_ids=batch["rejected_input_ids"],
rejected_attention_mask=batch["rejected_attention_mask"],
rejected_loss_mask=batch["rejected_loss_mask"],
frame_mask=batch.get("frame_mask"),
)
# Reference model forward (frozen, no grad)
with torch.no_grad():
ref_out = ref_model.forward_dpo(
frames=batch["frames"],
chosen_input_ids=batch["chosen_input_ids"],
chosen_attention_mask=batch["chosen_attention_mask"],
chosen_loss_mask=batch["chosen_loss_mask"],
rejected_input_ids=batch["rejected_input_ids"],
rejected_attention_mask=batch["rejected_attention_mask"],
rejected_loss_mask=batch["rejected_loss_mask"],
frame_mask=batch.get("frame_mask"),
)
# Compute DPO loss
dpo_result = compute_dpo_loss(
policy_chosen_logps=policy_out["chosen_logps"],
policy_rejected_logps=policy_out["rejected_logps"],
ref_chosen_logps=ref_out["chosen_logps"],
ref_rejected_logps=ref_out["rejected_logps"],
beta=dpo_beta,
)
loss = dpo_result["loss"] / grad_accum
# Store outputs for logging (mimic SFT outputs dict)
outputs = {
"loss": dpo_result["loss"],
"fine_loss": dpo_result["loss"], # alias for logger
"coarse_loss": torch.tensor(0.0, device=device),
"reward_accuracy": dpo_result["reward_accuracy"],
"chosen_reward": dpo_result["chosen_reward"].mean().item(),
"rejected_reward": dpo_result["rejected_reward"].mean().item(),
}
# Accumulate DPO metrics for logging at optimizer step
dpo_reward_acc_accum += dpo_result["reward_accuracy"]
dpo_chosen_reward_accum += dpo_result["chosen_reward"].mean().item()
dpo_rejected_reward_accum += dpo_result["rejected_reward"].mean().item()
dpo_micro_count += 1
scaler.scale(loss).backward()
else:
# ---- Standard SFT forward pass ----
with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
outputs = model(
frames=batch["frames"],
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
loss_mask=batch["loss_mask"],
frame_mask=batch.get("frame_mask"),
mode=train_mode,
)
loss = outputs["loss"] / grad_accum
scaler.scale(loss).backward()
                except torch.cuda.OutOfMemoryError:
                    # Rare: batch with too many real frames. Skip and continue.
                    if is_main_process():
                        # Count real frames via frame_mask when present; fall
                        # back to B*T (summing raw pixels would be meaningless).
                        if "frame_mask" in batch:
                            n_real = int(batch["frame_mask"].sum().item())
                        else:
                            n_real = batch["frames"].shape[0] * batch["frames"].shape[1]
                        print(f" [OOM] Skipping batch at step {global_step} "
                              f"(n_real={n_real}). Clearing cache.")
torch.cuda.empty_cache()
optimizer.zero_grad(set_to_none=True)
micro_step = 0 # reset accumulation
dpo_reward_acc_accum = 0.0
dpo_chosen_reward_accum = 0.0
dpo_rejected_reward_accum = 0.0
dpo_micro_count = 0
continue
samples_seen += batch["frames"].shape[0] * world_size
micro_step += 1
# Skip optimizer step if still accumulating
if is_accum:
continue
# ---- Optimizer step ----
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_grad_norm,
)
scaler.step(optimizer)
scaler.update()
scheduler.step()
optimizer.zero_grad(set_to_none=True) # nanochat: faster
global_step += 1
# ---- Logging ----
if is_main_process() and global_step % log_every == 0:
elapsed = time.time() - t0
samples_per_sec = samples_seen / max(elapsed, 1e-6)
lr_groups = {g.get("name", "default"): g["lr"]
for g in optimizer.param_groups}
if is_dpo and dpo_micro_count > 0:
# DPO-specific logging
avg_reward_acc = dpo_reward_acc_accum / dpo_micro_count
avg_chosen_reward = dpo_chosen_reward_accum / dpo_micro_count
avg_rejected_reward = dpo_rejected_reward_accum / dpo_micro_count
reward_margin = avg_chosen_reward - avg_rejected_reward
print(
f" step {global_step:6d} | dpo_loss {outputs['loss'].item():.4f} | "
f"rew_acc {avg_reward_acc:.3f} | margin {reward_margin:.3f} | "
f"lr {scheduler.get_last_lr()[0]:.2e} | "
f"gnorm {(grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm):.2f} | "
f"{samples_per_sec:.0f} samp/s",
flush=True,
)
# Log to wandb/CSV via logger (use fine_loss slot for DPO loss)
logger.log_step(
step=global_step,
loss=outputs["loss"].item(),
fine_loss=outputs["loss"].item(),
coarse_loss=0.0,
lr=scheduler.get_last_lr()[0],
grad_norm=grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
samples_seen=samples_seen,
samples_per_sec=samples_per_sec,
lr_groups=lr_groups,
)
# Log DPO-specific metrics to wandb
try:
import wandb
if wandb.run is not None:
wandb.log({
"dpo/reward_accuracy": avg_reward_acc,
"dpo/chosen_reward": avg_chosen_reward,
"dpo/rejected_reward": avg_rejected_reward,
"dpo/reward_margin": reward_margin,
}, step=global_step)
except Exception:
pass
# Reset DPO accumulators
dpo_reward_acc_accum = 0.0
dpo_chosen_reward_accum = 0.0
dpo_rejected_reward_accum = 0.0
dpo_micro_count = 0
else:
# Standard SFT logging
logger.log_step(
step=global_step,
loss=outputs["loss"].item(),
fine_loss=outputs["fine_loss"].item(),
coarse_loss=outputs["coarse_loss"].item(),
lr=scheduler.get_last_lr()[0],
grad_norm=grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,
samples_seen=samples_seen,
samples_per_sec=samples_per_sec,
lr_groups=lr_groups,
)
# ---- Evaluation ----
if val_loader is not None and global_step % eval_every == 0:
attn_dir = os.path.join(ckpt_dir, "attention_maps") if is_main_process() else None
val_result = evaluate(
model, val_loader, device, amp_dtype, use_amp, cfg,
save_attn_dir=attn_dir, step=global_step,
)
if is_main_process():
logger.log_eval(
step=global_step,
val_loss=val_result["val_loss"],
val_fine_loss=val_result["val_fine_loss"],
val_coarse_loss=val_result["val_coarse_loss"],
attention_entropy=val_result["attention_entropy"],
)
model.train()
# ---- Checkpoint ----
if global_step % save_every == 0:
metric = None
if val_loader is not None and global_step % eval_every != 0:
val_result = evaluate(model, val_loader, device, amp_dtype, use_amp, cfg)
metric = val_result["val_loss"]
model.train()
elif val_loader is not None:
metric = val_result["val_loss"] # reuse from eval above
else:
# No val_loader — use train loss as metric (pretraining style)
metric = outputs["loss"].item() if isinstance(outputs["loss"], torch.Tensor) else outputs["loss"]
save_checkpoint(
model=model,
optimizer=optimizer,
scaler=scaler,
scheduler=scheduler,
step=global_step,
data_position=samples_seen,
save_dir=ckpt_dir,
metric_value=metric,
config=cfg,
)
epoch += 1
# ---- Final checkpoint ----
save_checkpoint(
model=model, optimizer=optimizer, scaler=scaler, scheduler=scheduler,
step=global_step, data_position=samples_seen, save_dir=ckpt_dir,
config=cfg,
)
if is_main_process():
elapsed = time.time() - t0
final_loss = outputs["loss"].item() if isinstance(outputs["loss"], torch.Tensor) else outputs["loss"]
logger.save_run_summary(final_loss=final_loss, total_samples=samples_seen)
logger.finish()
print(f"\n Training complete: {global_step} steps, "
f"{samples_seen:,} samples, {elapsed/3600:.1f}h")
cleanup_distributed()
if __name__ == "__main__":
args = parse_args()
cfg = load_config(args.config)
cfg["_config_path"] = args.config
train(cfg, args)