| | |
| | """ |
| | Unified training script for foveated VLM (all stages). |
| | |
| | Usage: |
| | # Single GPU |
| | python train.py --config configs/stage1_135M.yaml |
| | |
| | # Multi-GPU (2xA100) |
| | torchrun --nproc_per_node=2 train.py --config configs/stage1_135M.yaml |
| | |
| | # Dry run (verify config, dataloaders, shapes) |
| | python train.py --config configs/stage1_135M.yaml --dry-run |
| | """ |
| |
|
| | import argparse |
| | import gc |
| | import os |
| | import sys |
| | import time |
| | from contextlib import nullcontext |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import yaml |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| |
|
| | |
# Opt into the CUDA caching allocator's expandable segments unless the user
# already configured PYTORCH_CUDA_ALLOC_CONF (setdefault keeps an existing
# value). The setting only matters if it is in place before CUDA initialises.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Let fp32 matmuls/convolutions use TF32 tensor cores, and let cuDNN
# benchmark and cache the fastest kernels for the shapes it encounters.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
| |
|
| | |
| |
|
| | from model import FoveatedVLM |
| | from data import make_dataloader, make_dynamic_dataloader, create_dpo_webdataset |
| | from collate import collate_dpo |
| | from text_interleave import InterleavedDataLoader |
| | from distributed import ( |
| | setup_distributed, |
| | cleanup_distributed, |
| | is_main_process, |
| | get_rank, |
| | reduce_mean, |
| | ) |
| | from checkpoint import save_checkpoint, load_latest_checkpoint |
| | from schedule import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup, get_converging_schedule |
| | from logger import TrainingLogger |
| | from attention_viz import compute_attention_entropy, save_attention_maps |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def parse_args():
    """Parse the command line for a training run.

    Returns
    -------
    argparse.Namespace
        ``config``: path to the YAML config (required).
        ``dry_run``: True when ``--dry-run`` was given.
    """
    parser = argparse.ArgumentParser(description="fVLM training")
    parser.add_argument("--config", required=True, help="YAML config path")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse config, build model & dataloader, print shapes, exit",
    )
    namespace = parser.parse_args()
    return namespace
| |
|
| |
|
def load_config(path: str) -> dict:
    """Load a YAML training config from ``path`` and return it as a dict.

    Raises
    ------
    ValueError
        If the file is empty or its top level is not a mapping. Every caller
        indexes the result like ``cfg["training"]``, so anything else would
        otherwise fail later with a much less helpful TypeError/KeyError.
    """
    # Explicit UTF-8 so configs with non-ASCII comments load the same on
    # every platform regardless of the locale's default encoding.
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config {path!r} must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg
| |
|
| |
|
| | |
| | |
| | |
| |
|
def build_model(cfg: dict, device: torch.device):
    """Build a FoveatedVLM from ``cfg["model"]`` and prepare it for training.

    In order: construct the model, optionally load weights from
    ``model.init_from``, optionally enable gradient checkpointing, move it to
    ``device``, switch the DINO encoder to channels_last, and freeze the DINO
    encoder and/or LLM backbone when requested.

    Parameters
    ----------
    cfg : dict
        Full training config; only the ``"model"`` section is read.
    device : torch.device
        Device the model is moved to.

    Returns
    -------
    FoveatedVLM
        The model on ``device`` (not wrapped in DDP).
    """
    model_cfg = cfg["model"]
    model = FoveatedVLM(
        llm_name=model_cfg["llm"],
        dino_name=model_cfg["dino"],
        query_dim=model_cfg.get("query_dim", 384),
        visual_scale=model_cfg.get("visual_scale", 0.14),
        lambda_coarse=model_cfg.get("lambda_coarse", 0.0),
        deep_query=model_cfg.get("deep_query", True),
        use_fused_ce=model_cfg.get("use_fused_ce", False),
    )

    # Warm-start from an earlier checkpoint. A configured path that does not
    # exist is reported instead of being silently ignored (otherwise training,
    # and the DPO reference model, would start from fresh weights unnoticed).
    init_from = model_cfg.get("init_from")
    if init_from:
        if os.path.exists(init_from):
            # NOTE: weights_only=False unpickles arbitrary Python objects;
            # only point init_from at checkpoints you produced or trust.
            ckpt = torch.load(init_from, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt["model_state_dict"])
            del ckpt  # release the loaded checkpoint dict before moving to device
            if is_main_process():
                print(f" Loaded weights from {init_from}")
        elif is_main_process():
            print(f" WARNING: init_from={init_from!r} does not exist; "
                  f"training from freshly initialised weights", flush=True)

    if model_cfg.get("gradient_checkpointing", False):
        model.enable_gradient_checkpointing()

    model = model.to(device)

    # Use channels_last memory format for the DINO encoder when present.
    if hasattr(model, 'encoder') and hasattr(model.encoder, 'dino'):
        model.encoder.dino = model.encoder.dino.to(memory_format=torch.channels_last)

    # Locate the DINO module: model.encoder.dino, falling back to model.dino.
    dino_module = getattr(model, 'encoder', None)
    if dino_module is not None:
        dino_module = getattr(dino_module, 'dino', None)
    if dino_module is None:
        dino_module = getattr(model, 'dino', None)

    if model_cfg.get("freeze_dino", False) and dino_module is not None:
        for p in dino_module.parameters():
            p.requires_grad = False
        if is_main_process():
            print(" Frozen: DINO encoder")

    if model_cfg.get("freeze_llm", False):
        for p in model.llm.parameters():
            p.requires_grad = False
        if is_main_process():
            print(" Frozen: LLM backbone")

    return model
| |
|
| |
|
def _get_tokenizer(cfg: dict):
    """Load the tokenizer of the configured LLM (``cfg["model"]["llm"]``).

    ``transformers`` is imported lazily so the module can be imported without
    it. Tokenizers that ship without a pad token reuse their EOS token for
    padding.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["llm"])
    missing_pad = tokenizer.pad_token is None
    if missing_pad:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
| |
|
| |
|
def build_train_loader(cfg: dict, epoch: int = 0):
    """Build the training loader for ``epoch``.

    The vision loader is either frame-budgeted (``training.dynamic_batching``)
    or fixed-batch. When ``data.text_ratio > 0`` and ``data.text_shards`` is
    set, it is wrapped in an ``InterleavedDataLoader`` that mixes in text-only
    batches at that ratio; otherwise the vision loader is returned directly.
    """
    tokenizer = _get_tokenizer(cfg)
    data_cfg = cfg["data"]
    train_cfg = cfg["training"]
    stage = cfg.get("stage", 1)
    seed = train_cfg.get("seed", 42)
    num_workers = data_cfg.get("num_workers", 12)
    prefetch = data_cfg.get("prefetch_factor", 8)

    # Arguments shared by both vision-loader flavours.
    shared = dict(
        shard_pattern=data_cfg["train_shards"],
        max_frames=data_cfg.get("max_frames", 64),
        min_frames=data_cfg.get("min_frames", 0),
        shuffle=True,
        seed=seed,
        epoch=epoch,
        num_workers=num_workers,
        prefetch_factor=prefetch,
        tokenizer=tokenizer,
        stage=stage,
        replicate_image_frames=data_cfg.get("replicate_image_frames", 1),
    )

    if train_cfg.get("dynamic_batching", False):
        vision_loader = make_dynamic_dataloader(
            max_total_frames=train_cfg.get("max_total_frames", 512),
            max_batch_size=train_cfg.get("max_batch_size", 64),
            **shared,
        )
    else:
        vision_loader = make_dataloader(batch_size=train_cfg["batch_size"], **shared)

    text_ratio = data_cfg.get("text_ratio", 0.0)
    if not (text_ratio > 0 and data_cfg.get("text_shards")):
        return vision_loader

    # Text-only batches: a single "frame" slot, half the vision workers.
    text_loader = make_dataloader(
        shard_pattern=data_cfg["text_shards"],
        batch_size=train_cfg["batch_size"],
        max_frames=1,
        shuffle=True,
        seed=seed,
        epoch=epoch,
        num_workers=max(1, num_workers // 2),
        prefetch_factor=prefetch,
        tokenizer=tokenizer,
        stage=stage,
    )
    return InterleavedDataLoader(
        vision_loader=vision_loader,
        text_loader=text_loader,
        text_ratio=text_ratio,
        seed=seed + epoch,
    )
| |
|
| |
|
def build_dpo_train_loader(cfg: dict, epoch: int = 0):
    """Build the training DataLoader for DPO (preference) data.

    Samples come from ``create_dpo_webdataset`` over ``data.train_shards`` and
    are batched with ``collate_dpo``.

    Parameters
    ----------
    cfg : dict
        Full training config (reads ``data.*`` and ``training.batch_size``/``seed``).
    epoch : int
        Epoch index forwarded to the dataset (used for its shuffling).
    """
    tokenizer = _get_tokenizer(cfg)
    data_cfg = cfg["data"]
    num_workers = data_cfg.get("num_workers", 2)

    dataset = create_dpo_webdataset(
        shard_pattern=data_cfg["train_shards"],
        tokenizer=tokenizer,
        max_frames=data_cfg.get("max_frames", 64),
        shuffle=True,
        seed=cfg["training"].get("seed", 42),
        epoch=epoch,
        num_workers=num_workers,
        replicate_image_frames=data_cfg.get("replicate_image_frames", 1),
    )

    loader_kwargs = dict(
        batch_size=cfg["training"]["batch_size"],
        num_workers=num_workers,
        collate_fn=collate_dpo,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    # DataLoader raises ValueError if prefetch_factor is given while
    # num_workers == 0 (there are no worker processes to prefetch with).
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = data_cfg.get("prefetch_factor", 2)

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
| |
|
| |
|
def build_reference_model(cfg: dict, device: torch.device):
    """
    Build the frozen reference model used by DPO.

    It is constructed exactly like the policy via ``build_model`` (so it loads
    the same ``init_from`` checkpoint), switched to eval mode, and every
    parameter has ``requires_grad`` turned off.
    """
    reference = build_model(cfg, device).eval()
    reference.requires_grad_(False)
    return reference
| |
|
| |
|
def compute_dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> dict:
    """
    Direct Preference Optimization loss plus reward statistics.

    With implicit rewards r_c = beta * (pi_chosen - ref_chosen) and
    r_r = beta * (pi_rejected - ref_rejected), the loss is
    mean(-log_sigmoid(r_c - r_r)).

    Parameters
    ----------
    policy_chosen_logps : [B] policy log-probs of the chosen responses
    policy_rejected_logps : [B] policy log-probs of the rejected responses
    ref_chosen_logps : [B] reference log-probs of the chosen responses
    ref_rejected_logps : [B] reference log-probs of the rejected responses
    beta : float, DPO temperature

    Returns
    -------
    dict
        loss : scalar tensor
        reward_accuracy : float, fraction of pairs with r_c > r_r
        chosen_reward : [B] tensor r_c
        rejected_reward : [B] tensor r_r
    """
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps)

    margins = chosen_reward - rejected_reward
    log_probs = torch.nn.functional.logsigmoid(margins)
    loss = log_probs.mean().neg()

    # A pair counts as "correct" when the chosen reward strictly wins.
    accuracy = margins.gt(0).float().mean().item()

    return dict(
        loss=loss,
        reward_accuracy=accuracy,
        chosen_reward=chosen_reward,
        rejected_reward=rejected_reward,
    )
| |
|
| |
|
def build_val_loader(cfg: dict):
    """Return an unshuffled, single-process validation loader, or None.

    None is returned when ``data.val_shards`` is missing or empty.
    """
    data_cfg = cfg["data"]
    shards = data_cfg.get("val_shards")
    if shards:
        stage = cfg.get("stage", 1)
        tokenizer = _get_tokenizer(cfg)
        return make_dataloader(
            shard_pattern=shards,
            batch_size=cfg["training"]["batch_size"],
            max_frames=data_cfg.get("max_frames", 64),
            shuffle=False,
            num_workers=0,
            tokenizer=tokenizer,
            stage=stage,
        )
    return None
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def evaluate(model, val_loader, device, amp_dtype, use_amp, cfg,
             save_attn_dir=None, step=0):
    """Run validation and return average losses plus attention entropy.

    Parameters
    ----------
    model : nn.Module
        Possibly DDP-wrapped model; the unwrapped module is used for the
        attention-entropy probe.
    val_loader : iterable of batch dicts
        Batches carrying ``frames``, ``input_ids``, ``attention_mask``,
        ``loss_mask`` and optionally ``frame_mask``.
    device : torch.device
        Device for batch tensors and the cross-rank reduction.
    amp_dtype, use_amp
        Autocast settings mirroring training.
    cfg : dict
        Reads ``eval.max_samples`` (per-rank sample budget, default 1000) and
        ``model.coarse_only``.
    save_attn_dir : str or None
        If set, attention maps (sample 0, first up-to-4 frames) are saved
        for up to 10 batches.
    step : int
        Global step forwarded to ``save_attention_maps``.

    Returns
    -------
    dict with ``val_loss``, ``val_fine_loss``, ``val_coarse_loss`` and
    ``attention_entropy``. All four are sample-weighted means over every rank:
    sums and counts are reduced separately, so ranks that processed fewer
    samples do not skew the average. The model's previous train/eval mode is
    restored before returning.
    """
    was_training = model.training
    model.eval()
    raw_model = model.module if hasattr(model, "module") else model
    is_foveated = hasattr(raw_model, "encoder")

    total_loss = 0.0
    total_fine = 0.0
    total_coarse = 0.0
    total_entropy = 0.0
    entropy_count = 0
    count = 0
    max_samples = cfg.get("eval", {}).get("max_samples", 1000)
    eval_mode = "coarse_only" if cfg["model"].get("coarse_only", False) else "coarse_fine"
    attn_samples_saved = 0
    max_attn_saves = 10
    probe_error_reported = False

    for batch in val_loader:
        if count >= max_samples:
            break

        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
            outputs = model(
                frames=batch["frames"],
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                loss_mask=batch["loss_mask"],
                frame_mask=batch.get("frame_mask"),
                mode=eval_mode,
            )

        bs = batch["frames"].shape[0]
        total_loss += outputs["loss"].item() * bs
        total_fine += outputs.get("fine_loss", outputs["loss"]).item() * bs
        coarse = outputs.get("coarse_loss")
        if coarse is not None:
            total_coarse += float(coarse) * bs
        count += bs

        # Attention-entropy probe (and optional map dumps), stopping once
        # roughly 50 samples have been measured on this rank.
        if is_foveated and entropy_count < 50:
            try:
                frames = batch["frames"]
                B, T = frames.shape[:2]
                kv_cache, _, mask_flat = raw_model._encode_all_frames(frames)
                q_static = raw_model.q_static.expand(B, -1)
                # Entropy of the static query's attention over frame 0.
                frame0_kv = raw_model._extract_frame_kv(kv_cache, mask_flat, B, T, 0)
                _, attn_w = raw_model.encoder.query_attend(
                    q_static, frame0_kv, return_attention=True,
                )
                total_entropy += float(compute_attention_entropy(attn_w)) * bs
                entropy_count += bs

                if save_attn_dir and attn_samples_saved < max_attn_saves:
                    for t in range(min(T, 4)):
                        if t > 0:  # frame 0's weights were computed above
                            frame_kv = raw_model._extract_frame_kv(kv_cache, mask_flat, B, T, t)
                            _, attn_w = raw_model.encoder.query_attend(
                                q_static, frame_kv, return_attention=True,
                            )
                        save_attention_maps(
                            attn_w, save_attn_dir, step,
                            sample_idx=0, frame_idx=t,
                            prefix=f"attn_s{attn_samples_saved:03d}",
                        )
                    attn_samples_saved += 1
            except Exception as exc:
                # Best-effort diagnostics must never fail validation, but a
                # broken probe should not be invisible either: warn once.
                if not probe_error_reported:
                    probe_error_reported = True
                    if is_main_process():
                        print(f" [eval] attention probe failed "
                              f"({type(exc).__name__}: {exc}); skipping", flush=True)

    # One collective for all statistics: reduce sums and counts, then divide.
    stats = reduce_mean(torch.tensor(
        [total_loss, total_fine, total_coarse, float(count),
         total_entropy, float(entropy_count)],
        dtype=torch.float64, device=device,
    ))
    sum_loss, sum_fine, sum_coarse, n_eval, sum_entropy, n_entropy = stats.tolist()

    if n_eval > 0:
        avg_loss = sum_loss / n_eval
        avg_fine = sum_fine / n_eval
        avg_coarse = sum_coarse / n_eval
    else:
        avg_loss = avg_fine = avg_coarse = 0.0
    avg_entropy = sum_entropy / n_entropy if n_entropy > 0 else 0.0

    model.train(was_training)

    return {
        "val_loss": avg_loss,
        "val_fine_loss": avg_fine,
        "val_coarse_loss": avg_coarse,
        "attention_entropy": avg_entropy,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _maximize_batch_size(cfg: dict, device: torch.device): |
| | """ |
| | Increase batch_size and decrease grad_accum to keep the same effective |
| | batch while processing more samples per forward pass. Larger micro-batches |
| | improve GPU utilization by giving the GPU more parallel work. |
| | |
| | The effective batch (batch_size * grad_accum * world_size) stays constant |
| | so learning dynamics are unchanged. |
| | """ |
| | bs = cfg["training"]["batch_size"] |
| | ga = cfg["training"]["grad_accum"] |
| | effective = bs * ga |
| |
|
| | |
| | if torch.cuda.is_available(): |
| | total_gb = torch.cuda.get_device_properties(device).total_memory / 1e9 |
| | else: |
| | return |
| |
|
| | |
| | llm_path = cfg["model"].get("llm", "") |
| | if "1.7B" in llm_path or "1.7b" in llm_path: |
| | max_bs = 8 |
| | elif "360M" in llm_path or "360m" in llm_path: |
| | max_bs = min(effective, 16) |
| | else: |
| | |
| | |
| | |
| | |
| | max_bs = min(effective, 16) |
| |
|
| | if max_bs <= bs: |
| | return |
| |
|
| | new_ga = max(1, effective // max_bs) |
| | new_bs = effective // new_ga |
| |
|
| | if new_bs > bs: |
| | if is_main_process(): |
| | print(f" [THROUGHPUT] Batch size: {bs}×{ga} → {new_bs}×{new_ga} " |
| | f"(effective={new_bs * new_ga}, was {effective})") |
| | cfg["training"]["batch_size"] = new_bs |
| | cfg["training"]["grad_accum"] = new_ga |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _dpo_forward_kwargs(batch: dict) -> dict:
    """Return the ``forward_dpo`` keyword arguments for a collated DPO batch.

    The policy and the frozen reference model get identical inputs; building
    the mapping once keeps the two calls from drifting apart.
    """
    return {
        "frames": batch["frames"],
        "chosen_input_ids": batch["chosen_input_ids"],
        "chosen_attention_mask": batch["chosen_attention_mask"],
        "chosen_loss_mask": batch["chosen_loss_mask"],
        "rejected_input_ids": batch["rejected_input_ids"],
        "rejected_attention_mask": batch["rejected_attention_mask"],
        "rejected_loss_mask": batch["rejected_loss_mask"],
        "frame_mask": batch.get("frame_mask"),
    }


def train(cfg: dict, args):
    """Run the training loop described by ``cfg``.

    Supports single-process and DDP runs, the standard captioning loss and DPO
    (``cfg["loss"]["type"] == "dpo"``), gradient accumulation, AMP, optional
    torch.compile, periodic evaluation/checkpointing and auto-resume. With
    ``args.dry_run`` it builds everything, prints one batch's shapes and
    returns without training.
    """
    rank, world_size, device = setup_distributed()

    # DPO is selected through the loss section; beta is the DPO temperature.
    is_dpo = cfg.get("loss", {}).get("type") == "dpo"
    dpo_beta = cfg.get("loss", {}).get("beta", 0.1)

    if is_main_process():
        print(f"=== fVLM Training: Stage {cfg['stage']} ===")
        print(f" World size: {world_size}")
        print(f" Device: {device}")
        print(f" Dtype: {cfg['training'].get('dtype', 'float32')}")
        if is_dpo:
            print(f" Loss type: DPO (beta={dpo_beta})")

    # ------------------------------------------------------------------ model
    model = build_model(cfg, device)

    if world_size > 1:
        # NOTE(review): device_ids uses the global rank, which assumes a single
        # node where rank == local GPU index.
        model = DDP(model, device_ids=[rank])

    raw_model = model.module if hasattr(model, "module") else model

    # Frozen reference policy for DPO (loads the same init_from weights).
    ref_model = None
    if is_dpo:
        if is_main_process():
            print(" Loading reference model (frozen) ...")
        ref_model = build_reference_model(cfg, device)
        if is_main_process():
            ref_params = sum(p.numel() for p in ref_model.parameters())
            print(f" Reference model: {ref_params:,} params (all frozen)")

    # Re-configure gradient checkpointing with run-specific options
    # (build_model already enabled it with the model's defaults).
    if cfg["model"].get("gradient_checkpointing", False):
        if hasattr(raw_model, "enable_gradient_checkpointing"):
            llm_only = cfg["training"].get("compile_encoder", False)
            # Non-reentrant checkpointing is chosen when compiling.
            # NOTE(review): compile is force-disabled further below when
            # gradient checkpointing is on, after this choice is made.
            use_compile = cfg["training"].get("compile", False)
            use_reentrant = not use_compile
            raw_model.enable_gradient_checkpointing(
                llm_only=llm_only, use_reentrant=use_reentrant,
            )
            if is_main_process():
                mode = "LLM only" if llm_only else "LLM + DINO"
                reentrant_str = "reentrant" if use_reentrant else "non-reentrant (compile-safe)"
                print(f" Gradient checkpointing: {mode}, {reentrant_str}")

    # -------------------------------------------------------------- optimizer
    param_groups = raw_model.get_param_groups(
        lr_backbone=cfg["training"].get("lr_dino", 1e-5),
        lr_connector=cfg["training"].get("lr_connector", 1e-4),
    )
    # Optional learning-rate override for the param group named "llm".
    llm_lr = cfg["training"].get("lr_llm")
    if llm_lr is not None:
        for g in param_groups:
            if g.get("name") == "llm":
                g["lr"] = llm_lr

    optimizer = torch.optim.AdamW(
        param_groups,
        weight_decay=cfg["training"].get("weight_decay", 0.01),
        fused=True,
    )

    # --------------------------------------------------------------- schedule
    grad_accum = cfg["training"].get("grad_accum", 1)
    effective_batch = cfg["training"]["batch_size"] * grad_accum * world_size
    total_steps = cfg["training"]["total_samples"] // effective_batch
    warmup_steps = int(total_steps * cfg["training"].get("warmup_ratio", 0.05))

    schedule_type = cfg["training"].get("schedule", "cosine")
    if schedule_type == "constant":
        scheduler = get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
        )
    elif schedule_type == "converging":
        target_lr = cfg["training"].get("target_lr", 3e-5)
        scheduler = get_converging_schedule(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            target_lr=target_lr,
        )
        if is_main_process():
            print(f" Schedule: converging to target_lr={target_lr} (100:1 → 1:1)")
    else:
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

    # -------------------------------------------------------------------- AMP
    dtype_str = cfg["training"].get("dtype", "float32")
    use_amp = dtype_str in ("bfloat16", "float16")
    amp_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16}.get(
        dtype_str, torch.float32
    )
    # Loss scaling is only enabled for float16.
    scaler = torch.amp.GradScaler(enabled=(dtype_str == "float16"))

    # ----------------------------------------------------------------- resume
    start_step = 0
    data_position = 0
    ckpt_dir = cfg["checkpoint"]["save_dir"]

    if cfg["checkpoint"].get("resume") == "auto":
        resume_info = load_latest_checkpoint(
            ckpt_dir, model, optimizer, scaler, scheduler,
            map_location=str(device),
        )
        if resume_info:
            start_step = resume_info["step"]
            data_position = resume_info["data_position"]

    logger = TrainingLogger(
        project=cfg.get("wandb", {}).get("project", "foveated-vlm"),
        config=cfg,
        enabled=is_main_process(),
    )

    # ---------------------------------------------------------------- compile
    if cfg["training"].get("compile", False) and cfg["model"].get("gradient_checkpointing", False):
        if is_main_process():
            print(" WARNING: compile=true + gradient_checkpointing=true produces NaN loss!")
            print(" Disabling torch.compile. Set compile=false in config to suppress this warning.")
        cfg["training"]["compile"] = False

    if cfg["training"].get("compile", False) and hasattr(torch, "compile"):
        compile_mode = cfg["training"].get("compile_mode", "reduce-overhead")
        if is_main_process():
            print(f" Compiling model with torch.compile ({compile_mode}) ...")
        fullgraph_encoder = cfg["training"].get("fullgraph_encoder", True)
        # Encoder: static shapes; LLM: dynamic shapes (variable sequence length).
        raw_model.encoder = torch.compile(
            raw_model.encoder, mode=compile_mode, dynamic=False, fullgraph=fullgraph_encoder,
        )
        raw_model.llm = torch.compile(raw_model.llm, mode=compile_mode, dynamic=True)
        raw_model.dino_to_llm = torch.compile(raw_model.dino_to_llm, mode=compile_mode)
        raw_model.llm_to_query = torch.compile(raw_model.llm_to_query, mode=compile_mode)
    elif cfg["training"].get("compile_encoder", False) and hasattr(torch, "compile"):
        compile_mode = cfg["training"].get("compile_mode", "reduce-overhead")
        if is_main_process():
            print(f" Compiling DINO encoder only with torch.compile ({compile_mode}) ...")
        raw_model.encoder = torch.compile(raw_model.encoder, mode=compile_mode, dynamic=False)

    val_loader = build_val_loader(cfg)

    # ---------------------------------------------------------------- dry run
    if args.dry_run:
        if is_main_process():
            if is_dpo:
                loader = build_dpo_train_loader(cfg, epoch=0)
            else:
                loader = build_train_loader(cfg, epoch=0)
            batch = next(iter(loader))
            print(f"\n Dry run OK:")
            for k, v in batch.items():
                shape = v.shape if hasattr(v, "shape") else type(v).__name__
                print(f" {k:20s} {shape}")
            print(f" total_steps = {total_steps}")
            print(f" warmup_steps = {warmup_steps}")
            print(f" effective_batch = {effective_batch}")
            n_params = sum(p.numel() for p in raw_model.parameters())
            n_train = sum(p.numel() for p in raw_model.parameters() if p.requires_grad)
            print(f" total_params = {n_params:,}")
            print(f" trainable_params = {n_train:,}")
        cleanup_distributed()
        return

    if is_main_process():
        n_params = sum(p.numel() for p in raw_model.parameters())
        print(f" Parameters: {n_params:,}")
        print(f" Total steps: {total_steps}")
        print(f" Warmup: {warmup_steps}")
        print(f" Eff. batch: {effective_batch}")
        print(f" Starting at: step={start_step}, samples={data_position}")
        print()

    # ------------------------------------------------------------------- loop
    max_grad_norm = cfg["training"].get("max_grad_norm", 1.0)
    save_every = cfg["checkpoint"].get("save_every_steps", 1000)
    eval_every = cfg.get("eval", {}).get("every_steps", 500)
    # Logging cadence in optimizer steps (default 10).
    log_every = cfg["training"].get("log_every", 10)
    train_mode = "coarse_only" if cfg["model"].get("coarse_only", False) else "coarse_fine"
    if is_main_process() and train_mode != "coarse_fine":
        print(f" Train mode: {train_mode}")

    global_step = start_step
    samples_seen = data_position
    # NOTE(review): on resume the loader restarts from the beginning of this
    # epoch; it is not fast-forwarded to data_position, so samples may repeat.
    epoch = data_position // max(cfg["training"]["total_samples"], 1)
    micro_step = 0
    # Outputs of the latest successful micro-batch. Stays None if no batch is
    # processed (e.g. resuming a run that already reached total_steps).
    outputs = None

    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Turn off automatic cyclic GC during the hot loop; garbage is collected
    # explicitly once per epoch and automatic GC is re-enabled at the end.
    gc.collect()
    gc.disable()

    t0 = time.time()

    # Running DPO metric sums between log lines (averaged when logged).
    dpo_reward_acc_accum = 0.0
    dpo_chosen_reward_accum = 0.0
    dpo_rejected_reward_accum = 0.0
    dpo_micro_count = 0

    while global_step < total_steps:
        if is_dpo:
            train_loader = build_dpo_train_loader(cfg, epoch=epoch)
        else:
            train_loader = build_train_loader(cfg, epoch=epoch)

        batches_this_epoch = 0
        for batch in train_loader:
            if global_step >= total_steps:
                break
            batches_this_epoch += 1

            batch = {
                k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }

            # Skip DDP's gradient all-reduce on all but the last micro-batch of
            # each accumulation window.
            is_accum = ((micro_step + 1) % grad_accum != 0)
            sync_ctx = model.no_sync() if (world_size > 1 and is_accum) else nullcontext()

            oom = False
            with sync_ctx:
                try:
                    if is_dpo:
                        dpo_kwargs = _dpo_forward_kwargs(batch)
                        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
                            # NOTE(review): forward_dpo is called on the unwrapped
                            # module, bypassing DDP.forward; verify gradient
                            # synchronisation when world_size > 1.
                            policy_out = raw_model.forward_dpo(**dpo_kwargs)
                            # Reference log-probs: no autograd graph, same
                            # autocast precision as the policy.
                            with torch.no_grad():
                                ref_out = ref_model.forward_dpo(**dpo_kwargs)

                        dpo_result = compute_dpo_loss(
                            policy_chosen_logps=policy_out["chosen_logps"],
                            policy_rejected_logps=policy_out["rejected_logps"],
                            ref_chosen_logps=ref_out["chosen_logps"],
                            ref_rejected_logps=ref_out["rejected_logps"],
                            beta=dpo_beta,
                        )
                        loss = dpo_result["loss"] / grad_accum

                        chosen_reward = dpo_result["chosen_reward"].mean().item()
                        rejected_reward = dpo_result["rejected_reward"].mean().item()
                        # Same keys as the standard loss path so logging and
                        # checkpoint code can treat both modes uniformly.
                        outputs = {
                            "loss": dpo_result["loss"],
                            "fine_loss": dpo_result["loss"],
                            "coarse_loss": torch.tensor(0.0, device=device),
                            "reward_accuracy": dpo_result["reward_accuracy"],
                            "chosen_reward": chosen_reward,
                            "rejected_reward": rejected_reward,
                        }

                        dpo_reward_acc_accum += dpo_result["reward_accuracy"]
                        dpo_chosen_reward_accum += chosen_reward
                        dpo_rejected_reward_accum += rejected_reward
                        dpo_micro_count += 1

                        scaler.scale(loss).backward()
                    else:
                        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
                            outputs = model(
                                frames=batch["frames"],
                                input_ids=batch["input_ids"],
                                attention_mask=batch["attention_mask"],
                                loss_mask=batch["loss_mask"],
                                frame_mask=batch.get("frame_mask"),
                                mode=train_mode,
                            )
                        loss = outputs["loss"] / grad_accum

                        scaler.scale(loss).backward()
                except torch.cuda.OutOfMemoryError:
                    oom = True

            if oom:
                # Recovery runs outside the except clause so the exception and
                # the activations its traceback references are released before
                # the cache is emptied. Drop partial results for the same reason.
                policy_out = ref_out = dpo_result = loss = outputs = None
                frame_mask = batch.get("frame_mask")
                n_real = frame_mask.sum().item() if isinstance(frame_mask, torch.Tensor) else "n/a"
                # Printed on every rank: an OOM on a non-main rank is otherwise invisible.
                print(f" [OOM] [rank {rank}] Skipping batch at step {global_step} "
                      f"(n_real={n_real}). Clearing cache.", flush=True)
                torch.cuda.empty_cache()
                optimizer.zero_grad(set_to_none=True)
                micro_step = 0
                dpo_reward_acc_accum = 0.0
                dpo_chosen_reward_accum = 0.0
                dpo_rejected_reward_accum = 0.0
                dpo_micro_count = 0
                continue

            samples_seen += batch["frames"].shape[0] * world_size
            micro_step += 1

            if is_accum:
                continue

            # Optimizer step: unscale (no-op unless fp16), clip, step, schedule.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_grad_norm,
            )
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if is_main_process() and global_step % log_every == 0:
                elapsed = time.time() - t0
                # Count only samples processed by this run; samples_seen starts
                # at the resumed data_position.
                samples_per_sec = (samples_seen - data_position) / max(elapsed, 1e-6)
                lr_groups = {g.get("name", "default"): g["lr"]
                             for g in optimizer.param_groups}
                current_lr = scheduler.get_last_lr()[0]
                gnorm = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm

                if is_dpo and dpo_micro_count > 0:
                    # Averages over every micro-batch since the last log line.
                    avg_reward_acc = dpo_reward_acc_accum / dpo_micro_count
                    avg_chosen_reward = dpo_chosen_reward_accum / dpo_micro_count
                    avg_rejected_reward = dpo_rejected_reward_accum / dpo_micro_count
                    reward_margin = avg_chosen_reward - avg_rejected_reward

                    print(
                        f" step {global_step:6d} | dpo_loss {outputs['loss'].item():.4f} | "
                        f"rew_acc {avg_reward_acc:.3f} | margin {reward_margin:.3f} | "
                        f"lr {current_lr:.2e} | "
                        f"gnorm {gnorm:.2f} | "
                        f"{samples_per_sec:.0f} samp/s",
                        flush=True,
                    )

                    logger.log_step(
                        step=global_step,
                        loss=outputs["loss"].item(),
                        fine_loss=outputs["loss"].item(),
                        coarse_loss=0.0,
                        lr=current_lr,
                        grad_norm=gnorm,
                        samples_seen=samples_seen,
                        samples_per_sec=samples_per_sec,
                        lr_groups=lr_groups,
                    )

                    # Extra DPO metrics go straight to wandb when a run is active;
                    # this is best-effort and must never interrupt training.
                    try:
                        import wandb
                        if wandb.run is not None:
                            wandb.log({
                                "dpo/reward_accuracy": avg_reward_acc,
                                "dpo/chosen_reward": avg_chosen_reward,
                                "dpo/rejected_reward": avg_rejected_reward,
                                "dpo/reward_margin": reward_margin,
                            }, step=global_step)
                    except Exception:
                        pass

                    dpo_reward_acc_accum = 0.0
                    dpo_chosen_reward_accum = 0.0
                    dpo_rejected_reward_accum = 0.0
                    dpo_micro_count = 0
                else:
                    logger.log_step(
                        step=global_step,
                        loss=outputs["loss"].item(),
                        fine_loss=outputs["fine_loss"].item(),
                        coarse_loss=outputs["coarse_loss"].item(),
                        lr=current_lr,
                        grad_norm=gnorm,
                        samples_seen=samples_seen,
                        samples_per_sec=samples_per_sec,
                        lr_groups=lr_groups,
                    )

            # Periodic evaluation (all ranks participate in the reduction).
            if val_loader is not None and global_step % eval_every == 0:
                attn_dir = os.path.join(ckpt_dir, "attention_maps") if is_main_process() else None
                val_result = evaluate(
                    model, val_loader, device, amp_dtype, use_amp, cfg,
                    save_attn_dir=attn_dir, step=global_step,
                )
                if is_main_process():
                    logger.log_eval(
                        step=global_step,
                        val_loss=val_result["val_loss"],
                        val_fine_loss=val_result["val_fine_loss"],
                        val_coarse_loss=val_result["val_coarse_loss"],
                        attention_entropy=val_result["attention_entropy"],
                    )
                model.train()

            # Periodic checkpoint; the metric is the val loss when available
            # (reusing this step's eval if one just ran), else the train loss.
            if global_step % save_every == 0:
                metric = None
                if val_loader is not None and global_step % eval_every != 0:
                    val_result = evaluate(model, val_loader, device, amp_dtype, use_amp, cfg)
                    metric = val_result["val_loss"]
                    model.train()
                elif val_loader is not None:
                    metric = val_result["val_loss"]
                else:
                    metric = outputs["loss"].item() if isinstance(outputs["loss"], torch.Tensor) else outputs["loss"]
                save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scaler=scaler,
                    scheduler=scheduler,
                    step=global_step,
                    data_position=samples_seen,
                    save_dir=ckpt_dir,
                    metric_value=metric,
                    config=cfg,
                )

        if batches_this_epoch == 0:
            # An empty loader would otherwise make the while-loop spin forever.
            raise RuntimeError(
                f"Training loader for epoch {epoch} yielded no batches "
                f"(step {global_step}/{total_steps}); check data.train_shards"
            )
        epoch += 1
        # Drop the finished epoch's loader and collect any reference cycles it
        # left behind while automatic GC is disabled.
        del train_loader
        gc.collect()

    # Final checkpoint.
    save_checkpoint(
        model=model, optimizer=optimizer, scaler=scaler, scheduler=scheduler,
        step=global_step, data_position=samples_seen, save_dir=ckpt_dir,
        config=cfg,
    )
    gc.enable()

    if is_main_process():
        elapsed = time.time() - t0
        if outputs is None:
            final_loss = float("nan")  # no batch was processed in this run
        else:
            final_loss = outputs["loss"].item() if isinstance(outputs["loss"], torch.Tensor) else outputs["loss"]
        logger.save_run_summary(final_loss=final_loss, total_samples=samples_seen)
        logger.finish()
        print(f"\n Training complete: {global_step} steps, "
              f"{samples_seen:,} samples, {elapsed/3600:.1f}h")

    cleanup_distributed()
| |
|
| |
|
if __name__ == "__main__":
    # CLI entry point: parse flags, load the YAML config, remember where it
    # came from, then run training.
    cli_args = parse_args()
    config = load_config(cli_args.config)
    config["_config_path"] = cli_args.config
    train(config, cli_args)
| |
|