#!/usr/bin/env python3
"""RoSA (Robust Sparse Adaptation) behavioral cloning on Lichess games.

Three training modes:

    rosa             -- Standard RoSA: LoRA warm-up -> gradient masks -> joint LoRA+sparse
    retro-sparse     -- Retrospective: LoRA warm-up -> masks -> restart sparse-only
    retro-bottleneck -- Retrospective: LoRA warm-up -> masks -> restart sparse+bottleneck

Usage:
    uv run python scripts/train_rosa.py \
        --checkpoint /path/to/checkpoint \
        --pgn /path/to/lichess.pgn \
        --mode rosa \
        --density 0.01 \
        --local-checkpoints
"""
from __future__ import annotations

import argparse
import gc
import math
import signal
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from pawn.config import CLMConfig, PAD_TOKEN
from pawn.model import PAWNCLM
from pawn.adapters.rosa import RoSACLM, RetroBottleneckCLM, generate_gradient_masks
from pawn.adapters.sparse import SparseCLM, SparseLinear
from pawn.adapters.lora import ATTN_PRESETS, _FFN_TARGETS
from pawn.logging import MetricsLogger
from pawn.gpu import configure_gpu, apply_gpu_config
from pawn.lichess_data import (
    compute_legal_indices,
    prepare_lichess_dataset,
    LegalMaskBuilder,
    LegalMaskCollate,
    LichessDataset,
)
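
# Example invocations for the two retrospective modes (illustrative only;
# paths and the repo name are placeholders, see parse_args below for the
# full flag set):
#
#   uv run python scripts/train_rosa.py --checkpoint ckpt.pt --pgn games.pgn \
#       --mode retro-sparse --density 0.01 --local-checkpoints
#
#   uv run python scripts/train_rosa.py --checkpoint ckpt.pt --pgn games.pgn \
#       --mode retro-bottleneck --bottleneck-dim 8 --hf-repo user/pawn-rosa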

def parse_args():
    p = argparse.ArgumentParser(description="RoSA BC on Lichess games")
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to PAWN checkpoint")
    p.add_argument("--pgn", type=str, required=True,
                   help="Path to Lichess PGN file (pre-filtered by Elo)")
    p.add_argument("--log-dir", type=str, default=None,
                   help="Parent log directory (default: /logs)")
    p.add_argument("--output-dir", type=str, default=None,
                   help="Explicit output directory (overrides --log-dir)")

    # Mode
    p.add_argument("--mode", type=str, required=True,
                   choices=["rosa", "retro-sparse", "retro-bottleneck"],
                   help="Training mode")

    # LoRA config (used during warm-up in all modes)
    p.add_argument("--lora-rank", type=int, default=4,
                   help="LoRA rank (default: 4)")
    p.add_argument("--lora-alpha", type=float, default=None,
                   help="LoRA alpha scaling (default: same as rank)")
    p.add_argument("--lora-targets", type=str, default="qkvo",
                   choices=["qkvo", "qv", "qkv"],
                   help="Which attention projections to adapt (default: qkvo)")
    p.add_argument("--lora-ffn", action="store_true",
                   help="Also apply adapters to FFN projections")

    # Sparse config
    p.add_argument("--density", type=float, default=0.01,
                   help="Sparse mask density (default: 0.01)")

    # Mask generation
    p.add_argument("--warmup-steps", type=int, default=128,
                   help="LoRA-only warm-up steps before mask generation (default: 128)")
    p.add_argument("--warmup-lr", type=float, default=None,
                   help="Learning rate for warm-up phase (default: same as --lr)")
    p.add_argument("--mask-samples", type=int, default=32,
                   help="Batches for gradient accumulation during mask generation (default: 32)")
    p.add_argument("--grad-alpha", type=int, default=2, choices=[1, 2],
                   help="Gradient accumulation exponent: 1=mean, 2=Fisher (default: 2)")

    # RoSA-specific
    p.add_argument("--restart-lora", action="store_true", default=True,
                   help="Re-initialize LoRA after mask generation (default: True)")
    p.add_argument("--no-restart-lora", action="store_false", dest="restart_lora",
                   help="Keep warm-up LoRA weights for joint training")

    # Bottleneck (retro-bottleneck mode only)
    p.add_argument("--bottleneck-dim", type=int, default=8,
                   help="Bottleneck adapter dimension (retro-bottleneck only, default: 8)")

    # Data
    p.add_argument("--max-games", type=int, default=12_000)
    p.add_argument("--val-games", type=int, default=2_000)
    p.add_argument("--min-ply", type=int, default=10)

    # Training (Phase 3)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight-decay", type=float, default=0.0)
    p.add_argument("--max-grad-norm", type=float, default=1.0)
    p.add_argument("--warmup-frac", type=float, default=0.05,
                   help="Fraction of Phase 3 steps for LR warmup")
    p.add_argument("--patience", type=int, default=10,
                   help="Early stopping patience (epochs)")
    p.add_argument("--val-every", type=int, default=1)

    # Device / precision
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--no-amp", action="store_true")
    p.add_argument("--no-compile", action="store_true")
    p.add_argument("--sdpa-math", action="store_true",
                   help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
    p.add_argument("--num-workers", type=int, default=8,
                   help="DataLoader workers for legal mask prefetch (default: 8)")

    ckpt_group = p.add_mutually_exclusive_group(required=True)
    ckpt_group.add_argument("--hf-repo", type=str, default=None,
                            help="Push checkpoints to this HuggingFace repo")
    ckpt_group.add_argument("--local-checkpoints", action="store_true",
                            help="Save checkpoints locally only")
    return p.parse_args()


def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
    from pawn.checkpoint import load_backbone_weights

    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
    model = PAWNCLM(cfg).to(device)
    model.load_state_dict(state_dict)
    del state_dict
    gc.collect()
    model.eval()
    return model


def cosine_warmup_schedule(optimizer, warmup_steps: int, total_steps: int):
    """Linear warmup then cosine decay to 0."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
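
# Worked example of the schedule above: with warmup_steps=100 and
# total_steps=1000, lr_lambda returns 0.0 at step 0, ramps linearly to 1.0
# at step 100, then decays as 0.5 * (1 + cos(pi * progress)); at step 550,
# progress = (550 - 100) / 900 = 0.5, so the multiplier is exactly 0.5,
# and it reaches ~0.0 at step 1000.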

def sparse_forward(model, ids, msk, legal_mask, use_amp, device):
    """Sparse forward: project only loss-masked positions through lm_head."""
    with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
        hidden = model.forward_hidden(ids, msk)
        valid_hidden = hidden[msk]
        valid_logits = model.project_head(valid_hidden)
        valid_legal = legal_mask[msk]
        valid_logits = valid_logits.float()
        valid_logits.masked_fill_(~valid_legal, float("-inf"))
    return valid_logits


@torch.no_grad()
def evaluate(model, dataloader, mask_builder, device, use_amp: bool = False,
             precomputed_indices: list[torch.Tensor] | None = None):
    model.eval()
    total_loss = 0.0
    total_top1 = 0.0
    total_top5 = 0.0
    total_positions = 0
    for i, batch in enumerate(dataloader):
        ids = batch["input_ids"].to(device, non_blocking=True)
        tgt = batch["targets"].to(device, non_blocking=True)
        msk = batch["loss_mask"].to(device, non_blocking=True)
        if precomputed_indices is not None:
            legal_mask = mask_builder.scatter(precomputed_indices[i], ids.shape[0])
        elif "legal_indices" in batch:
            legal_mask = mask_builder.scatter(batch["legal_indices"], ids.shape[0])
        else:
            legal_mask = mask_builder(batch)

        valid_logits = sparse_forward(model, ids, msk, legal_mask, use_amp, device)
        valid_targets = tgt[msk]
        n_pos = valid_targets.shape[0]
        if n_pos == 0:
            continue

        loss = F.cross_entropy(valid_logits, valid_targets)
        preds = valid_logits.argmax(dim=-1)
        top1 = (preds == valid_targets).float().mean().item()
        top5 = valid_logits.topk(5, dim=-1).indices
        top5_acc = (top5 == valid_targets.unsqueeze(-1)).any(dim=-1).float().mean().item()

        total_loss += loss.item() * n_pos
        total_top1 += top1 * n_pos
        total_top5 += top5_acc * n_pos
        total_positions += n_pos

    if total_positions == 0:
        return {"loss": 0.0, "top1_accuracy": 0.0, "top5_accuracy": 0.0}
    return {
        "loss": total_loss / total_positions,
        "top1_accuracy": total_top1 / total_positions,
        "top5_accuracy": total_top5 / total_positions,
    }


# ---------------------------------------------------------------------------
# Phase 1: LoRA warm-up
# ---------------------------------------------------------------------------

def run_warmup(model, train_loader, mask_builder, args, device, use_amp):
    """Train LoRA-only for warmup_steps steps. Returns the step count."""
    lr = args.warmup_lr if args.warmup_lr is not None else args.lr
    lora_params = model.lora_parameters()
    optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=args.weight_decay)
    scaler = torch.amp.GradScaler() if use_amp else None

    model.train()
    step = 0
    total_loss = 0.0
    t0 = time.time()
    print(f"\n=== Phase 1: LoRA warm-up ({args.warmup_steps} steps, lr={lr}) ===")
    while step < args.warmup_steps:
        for batch in train_loader:
            if step >= args.warmup_steps:
                break
            ids = batch["input_ids"].to(device, non_blocking=True)
            tgt = batch["targets"].to(device, non_blocking=True)
            msk = batch["loss_mask"].to(device, non_blocking=True)
            if "legal_indices" in batch:
                legal_mask = mask_builder.scatter(batch["legal_indices"], ids.shape[0])
            else:
                legal_mask = mask_builder(batch)

            valid_logits = sparse_forward(model, ids, msk, legal_mask, use_amp, device)
            valid_targets = tgt[msk]
            if valid_targets.shape[0] == 0:
                continue
            loss = F.cross_entropy(valid_logits, valid_targets)

            optimizer.zero_grad(set_to_none=True)
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(lora_params, args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(lora_params, args.max_grad_norm)
                optimizer.step()

            total_loss += loss.item()
            step += 1
            if step % 32 == 0 or step == args.warmup_steps:
                avg = total_loss / step
                print(f"  Warmup step {step}/{args.warmup_steps} | loss={avg:.4f}")

    dt = time.time() - t0
    print(f"  Warm-up complete in {dt:.1f}s (avg loss={total_loss / max(step, 1):.4f})")
    return step
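
# Illustrative sketch (an assumption -- generate_gradient_masks lives in
# pawn.adapters.rosa and its internals are not shown here). For one frozen
# weight matrix W, a gradient-based mask could be built roughly as:
#
#     score = torch.zeros_like(W)
#     for batch in itertools.islice(loader, max_batches):
#         model.zero_grad()
#         loss = bc_loss(batch)             # same BC objective as Phase 1
#         loss.backward()
#         score += W.grad.abs() ** alpha    # alpha=1 ~ mean |grad|, alpha=2 ~ Fisher
#     k = max(1, int(density * score.numel()))
#     thresh = score.flatten().topk(k).values[-1]
#     mask = score >= thresh                # keep the top `density` fraction
#
# `bc_loss` is a hypothetical stand-in for the masked cross-entropy computed
# via sparse_forward above.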

# ---------------------------------------------------------------------------
# Phase 2: Mask generation
# ---------------------------------------------------------------------------

def run_mask_generation(model, train_loader, mask_builder, args, device, use_amp):
    """Generate gradient-based sparse masks. Returns the mask dict."""
    print(f"\n=== Phase 2: Mask generation (density={args.density}, "
          f"alpha={args.grad_alpha}, samples={args.mask_samples}) ===")
    masks = generate_gradient_masks(
        model,
        train_loader,
        mask_builder,
        density=args.density,
        alpha=args.grad_alpha,
        device=device,
        use_amp=use_amp,
        max_batches=args.mask_samples,
    )

    # Log mask statistics
    total_active = 0
    total_elements = 0
    for key, mask in masks.items():
        n_active = mask.sum().item()
        n_total = mask.numel()
        total_active += n_active
        total_elements += n_total
        print(f"  {key}: {n_active:,} / {n_total:,} ({100*n_active/n_total:.2f}%)")
    print(f"  Total: {total_active:,} / {total_elements:,} "
          f"({100*total_active/total_elements:.2f}%)")
    return masks
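
# Note on metric aggregation: batches contribute different numbers of
# loss-masked positions, so evaluate() above and the Phase 3 loop below
# weight each batch's loss/accuracy by n_pos and divide by the running
# position total. For example, 10 positions at loss 2.0 plus 90 positions
# at loss 1.0 average to (10 * 2.0 + 90 * 1.0) / 100 = 1.1, not
# (2.0 + 1.0) / 2 = 1.5.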

# ---------------------------------------------------------------------------
# Phase 3: Main training loop
# ---------------------------------------------------------------------------

def train_loop(model, adapter_params, train_loader, val_loader, mask_builder,
               val_legal_indices, logger, args, device, use_amp, gpu_cfg,
               weight_report_fn):
    """Standard epoch-based training loop for Phase 3."""
    from pawn import model as model_module
    from pawn.checkpoint import save_adapter_checkpoint, push_checkpoint_to_hf

    # Compile forward_hidden for Phase 3
    model.forward_hidden = apply_gpu_config(gpu_cfg, model_module, model.forward_hidden)

    optimizer = torch.optim.AdamW(
        adapter_params,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    total_steps = args.epochs * len(train_loader)
    warmup_steps = int(args.warmup_frac * total_steps)
    scheduler = cosine_warmup_schedule(optimizer, warmup_steps, total_steps)
    scaler = torch.amp.GradScaler() if use_amp else None

    # Baseline
    print("\nBaseline (zero/identity adapters):")
    baseline = evaluate(model, val_loader, mask_builder, device, use_amp=use_amp,
                        precomputed_indices=val_legal_indices)
    print(f"  loss={baseline['loss']:.4f}, top1={baseline['top1_accuracy']:.4%}, "
          f"top5={baseline['top5_accuracy']:.4%}")
    logger.log_train(
        step=0,
        epoch=-1,
        train_loss=baseline["loss"],
        train_top1=baseline["top1_accuracy"],
        val_loss=baseline["loss"],
        val_top1=baseline["top1_accuracy"],
        val_top5=baseline["top5_accuracy"],
    )

    best_val_loss = float("inf")
    patience_counter = 0
    global_step = 0
    val_metrics = baseline
    ckpt_dir = logger.run_dir / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)
    hf_branch = None
    if args.hf_repo:
        hf_branch = f"run/{logger.run_dir.name}"

    _shutdown_requested = False

    def _graceful_exit(signum, frame):
        nonlocal _shutdown_requested
        _shutdown_requested = True

    signal.signal(signal.SIGTERM, _graceful_exit)
    signal.signal(signal.SIGINT, _graceful_exit)

    print(f"\n=== Phase 3: Main training ({args.epochs} epochs, {total_steps} steps) ===")
    print(f"  LR warmup: {warmup_steps} steps, LR: {args.lr}")

    epoch = -1
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_top1 = 0.0
        epoch_positions = 0
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].to(device, non_blocking=True)
            tgt = batch["targets"].to(device, non_blocking=True)
            msk = batch["loss_mask"].to(device, non_blocking=True)
            if "legal_indices" in batch:
                legal_mask = mask_builder.scatter(batch["legal_indices"], ids.shape[0])
            else:
                legal_mask = mask_builder(batch)

            valid_logits = sparse_forward(model, ids, msk, legal_mask, use_amp, device)
            valid_targets = tgt[msk]
            loss = F.cross_entropy(valid_logits, valid_targets)

            optimizer.zero_grad(set_to_none=True)
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(adapter_params, args.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(adapter_params, args.max_grad_norm)
                optimizer.step()
            scheduler.step()

            with torch.no_grad():
                preds = valid_logits.argmax(dim=-1)
                top1 = (preds == valid_targets).float().mean().item()
            n_pos = valid_targets.shape[0]
            epoch_loss += loss.item() * n_pos
            epoch_top1 += top1 * n_pos
            epoch_positions += n_pos
            global_step += 1

        dt = time.time() - t0
        train_loss = epoch_loss / max(epoch_positions, 1)
        train_top1 = epoch_top1 / max(epoch_positions, 1)

        do_val = (epoch % args.val_every == 0) or (epoch == args.epochs - 1)
        if do_val:
            val_metrics = evaluate(model, val_loader, mask_builder, device,
                                   use_amp=use_amp,
                                   precomputed_indices=val_legal_indices)

        report = weight_report_fn()
        logger.log_train(
            step=global_step,
            epoch=epoch,
            lr=optimizer.param_groups[0]["lr"],
            train_loss=train_loss,
            train_top1=train_top1,
            val_loss=val_metrics["loss"],
            val_top1=val_metrics["top1_accuracy"],
            val_top5=val_metrics["top5_accuracy"],
            epoch_time_s=dt,
            **report,
        )
        print(f"  Epoch {epoch:3d} | "
              f"train_loss={train_loss:.4f} train_top1={train_top1:.4%} | "
              f"val_loss={val_metrics['loss']:.4f} val_top1={val_metrics['top1_accuracy']:.4%} "
              f"val_top5={val_metrics['top5_accuracy']:.4%} | "
              f"{dt:.1f}s")

        if do_val:
            if val_metrics["loss"] < best_val_loss:
                best_val_loss = val_metrics["loss"]
                patience_counter = 0
                save_adapter_checkpoint(
                    ckpt_dir / "best",
                    model.adapter_state_dict(),
                    config=vars(args),
                    epoch=epoch,
                    step=global_step,
                    val_metrics=val_metrics,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    scaler=scaler,
                    extra={"best_val_loss": best_val_loss,
                           "patience_counter": patience_counter},
                )
                if args.hf_repo and hf_branch:
                    try:
                        push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo,
                                              hf_branch, step=global_step)
                        print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
                    except Exception as e:
                        print(f"WARNING: HF push failed: {e}")
            else:
                patience_counter += 1
                if patience_counter >= args.patience:
                    print(f"\n  Early stopping at epoch {epoch} (patience={args.patience})")
                    break

        if _shutdown_requested:
            print("Shutdown requested, saving checkpoint...")
            break

    # Save final checkpoint
    save_adapter_checkpoint(
        ckpt_dir / "final",
        model.adapter_state_dict(),
        config=vars(args),
        epoch=epoch,
        step=global_step,
        val_metrics=val_metrics,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        extra={"best_val_loss": best_val_loss, "patience_counter": patience_counter},
    )
    if args.hf_repo and hf_branch:
        try:
            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch,
                                  step=global_step)
            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
        except Exception as e:
            print(f"WARNING: HF push failed: {e}")

    return best_val_loss
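
# Illustrative sketch (an assumption -- SparseLinear is defined in
# pawn.adapters.sparse, not here). A masked-delta linear layer conceptually
# computes
#
#     y = x @ (W_frozen + mask * delta).T + b
#
# where `mask` is the fixed binary mask copied in by _make_sparse_with_masks
# below and `delta` is the only trainable tensor, so updates touch just the
# masked entries.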

# ---------------------------------------------------------------------------
# Mode-specific setup
# ---------------------------------------------------------------------------

def setup_rosa(model, masks, args):
    """Standard RoSA: apply masks, optionally reinit LoRA, train jointly."""
    model.set_masks(masks)
    if args.restart_lora:
        model.reinit_lora()
    params = model.adapter_parameters()
    n_lora = sum(p.numel() for p in model.lora_parameters())
    n_sparse = model.n_active_sparse_params()
    n_total = sum(p.numel() for p in params)
    print(f"\nRoSA joint training: {n_total:,} trainable params")
    print(f"  LoRA: {n_lora:,}, Sparse active: {n_sparse:,}")
    return model, params


def _make_sparse_with_masks(masks, args, device):
    """Reload backbone, create SparseCLM, overwrite random masks with
    gradient-derived ones."""
    backbone = load_backbone(args.checkpoint, device)
    attn_targets = ATTN_PRESETS[args.lora_targets]
    sparse_model = SparseCLM(
        backbone,
        density=args.density,
        attn_targets=attn_targets,
        adapt_ffn=args.lora_ffn,
    )
    # Overwrite random masks with gradient-derived masks
    for layer_idx in range(len(backbone.layers)):
        block = backbone.get_block(layer_idx)
        for proj_name in attn_targets:
            module = getattr(block.attn, proj_name, None)
            if isinstance(module, SparseLinear):
                key = f"layer{layer_idx}.{proj_name}"
                if key in masks:
                    module.mask.copy_(masks[key])
        if args.lora_ffn:
            for proj_name in _FFN_TARGETS:
                module = getattr(block.ffn, proj_name, None)
                if isinstance(module, SparseLinear):
                    key = f"layer{layer_idx}.{proj_name}"
                    if key in masks:
                        module.mask.copy_(masks[key])
    return sparse_model


def setup_retro_sparse(masks, args, device):
    """Retrospective sparse-only: reload backbone, apply gradient masks."""
    print("\nReloading fresh backbone for retrospective sparse training...")
    sparse_model = _make_sparse_with_masks(masks, args, device)
    params = sparse_model.sparse_parameters()
    n_active = sparse_model.n_active_params()
    n_total = sum(p.numel() for p in params)
    print(f"Retro-sparse: {n_active:,} active / {n_total:,} total sparse params")
    return sparse_model, params


def setup_retro_bottleneck(masks, args, device):
    """Retrospective sparse + bottleneck: reload, apply masks, add bottlenecks."""
    print("\nReloading fresh backbone for retrospective sparse+bottleneck training...")
    sparse_model = _make_sparse_with_masks(masks, args, device)
    # Wrap with bottleneck adapters
    model = RetroBottleneckCLM(
        sparse_model.backbone,
        bottleneck_dim=args.bottleneck_dim,
    ).to(device)
    params = model.adapter_parameters()
    n_sparse = sum(p.numel() for p in model.sparse_parameters())
    n_bottleneck = sum(p.numel() for p in model.bottleneck_parameters())
    n_total = sum(p.numel() for p in params)
    print(f"Retro-bottleneck: {n_total:,} trainable params")
    print(f"  Sparse: {n_sparse:,}, Bottleneck: {n_bottleneck:,}")
    return model, params
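
# Illustrative sketch (an assumption -- RetroBottleneckCLM is defined in
# pawn.adapters.rosa, not here). A bottleneck adapter of dimension d is
# conventionally a residual low-rank MLP around each block's hidden state,
#
#     h = h + up(act(down(h)))    # down: d_model -> d, up: d -> d_model
#
# so --bottleneck-dim 8 would train roughly 2 * d_model * 8 parameters per
# adapter site, on top of the sparse deltas.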

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    args = parse_args()
    device = args.device

    log_dir = (Path(args.log_dir) if args.log_dir
               else Path(__file__).resolve().parent.parent.parent / "logs")
    if args.output_dir:
        out_dir = Path(args.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        import psutil as _psutil
        logger = MetricsLogger.__new__(MetricsLogger)
        logger.slug = ""
        logger.run_dir = out_dir
        logger.metrics_path = out_dir / "metrics.jsonl"
        logger._file = open(logger.metrics_path, "a")
        logger._proc = _psutil.Process()
        logger._device = device
        logger._start_time = time.time()
    else:
        logger = MetricsLogger(str(log_dir), run_prefix=f"rosa-{args.mode}", device=device)
        out_dir = logger.run_dir
    ckpt_dir = out_dir / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)

    print(f"Mode: {args.mode}")
    print(f"Device: {device}")
    print(f"Output: {out_dir}")

    # Write config record
    logger.log_config(
        run_type="rosa",
        mode=args.mode,
        checkpoint=str(args.checkpoint),
        pgn=str(args.pgn),
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        patience=args.patience,
        warmup_frac=args.warmup_frac,
        max_grad_norm=args.max_grad_norm,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha if args.lora_alpha is not None else args.lora_rank,
        lora_targets=args.lora_targets,
        lora_ffn=args.lora_ffn,
        density=args.density,
        warmup_steps=args.warmup_steps,
        mask_samples=args.mask_samples,
        grad_alpha=args.grad_alpha,
        restart_lora=args.restart_lora,
        bottleneck_dim=args.bottleneck_dim if args.mode == "retro-bottleneck" else None,
    )

    # -----------------------------------------------------------------------
    # Prepare data
    # -----------------------------------------------------------------------
    print(f"\nPreparing Lichess data: {args.pgn}")
    data = prepare_lichess_dataset(
        args.pgn,
        max_ply=255,
        max_games=args.max_games,
        min_ply=args.min_ply,
    )
    n_total_games = data["n_games"]
    n_val = min(args.val_games, n_total_games // 5)
    n_train = n_total_games - n_val
    print(f"  Train: {n_train} games, Val: {n_val} games")

    train_ds = LichessDataset(data, start=0, end=n_train).share_memory()
    val_ds = LichessDataset(data, start=n_train, end=n_total_games)

    vocab_size = CLMConfig().vocab_size  # 4278
    max_ply = 255
    collate = LegalMaskCollate(seq_len=max_ply + 1, vocab_size=vocab_size)
    n_workers = args.num_workers
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=n_workers,
        pin_memory=True,
        persistent_workers=n_workers > 0,
        collate_fn=collate,
        multiprocessing_context='spawn' if n_workers > 0 else None,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    mask_builder = LegalMaskBuilder(
        args.batch_size,
        max_ply=255,
        vocab_size=vocab_size,
        device=device,
    )

    # GPU config (don't compile yet -- save that for Phase 3)
    from pawn import model as model_module
    gpu_cfg = configure_gpu(
        device,
        no_compile=True,
        no_amp=args.no_amp,
        sdpa_math=args.sdpa_math,
    )
    use_amp = gpu_cfg["use_amp"]

    # Precompute val legal indices
    val_legal_indices = []
    for batch in val_loader:
        move_ids = batch["move_ids"]
        if isinstance(move_ids, torch.Tensor):
            move_ids = move_ids.numpy()
        game_lengths = np.asarray(batch["game_length"], dtype=np.int16)
        indices = compute_legal_indices(
            move_ids,
            game_lengths,
            mask_builder.T,
            vocab_size,
        )
        val_legal_indices.append(torch.from_numpy(indices).pin_memory())
    print(f"  Precomputed legal masks for {len(val_legal_indices)} val batches")

    # -----------------------------------------------------------------------
    # Phase 1: LoRA warm-up
    # -----------------------------------------------------------------------
    print(f"\nLoading backbone: {args.checkpoint}")
    backbone = load_backbone(args.checkpoint, device)
    warmup_model = RoSACLM(
        backbone,
        rank=args.lora_rank,
        alpha=args.lora_alpha,
        attn_targets=args.lora_targets,
        adapt_ffn=args.lora_ffn,
        lora_enabled=True,
        sparse_enabled=False,
    ).to(device)

    run_warmup(warmup_model, train_loader, mask_builder, args, device, use_amp)

    # -----------------------------------------------------------------------
    # Phase 2: Mask generation
    # -----------------------------------------------------------------------
    masks = run_mask_generation(
        warmup_model, train_loader, mask_builder, args, device, use_amp,
    )

    # Save warm-up LoRA weights for posterity
    print("\nSaving warm-up LoRA weights...")
    from pawn.checkpoint import save_adapter_checkpoint  # used here and below
    save_adapter_checkpoint(
        ckpt_dir / "warmup",
        warmup_model.adapter_state_dict(),
        config=vars(args),
        epoch=-1,
        step=args.warmup_steps,
        val_metrics=None,
    )
    print(f"  Saved to {ckpt_dir / 'warmup'}")
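
    # Phases 1-2 ran uncompiled (configure_gpu above was called with
    # no_compile=True), presumably because swapping masks and re-initializing
    # LoRA between phases would invalidate a compiled graph; Phase 3 below
    # re-enables compilation once the adapter structure is final.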

    # -----------------------------------------------------------------------
    # Phase 3: Mode-dependent training
    # -----------------------------------------------------------------------
    # Re-enable compile for Phase 3
    gpu_cfg = configure_gpu(
        device,
        no_compile=args.no_compile,
        no_amp=args.no_amp,
        sdpa_math=args.sdpa_math,
    )

    if args.mode == "rosa":
        model, adapter_params = setup_rosa(warmup_model, masks, args)
        weight_report_fn = model.adapter_weight_report
    else:
        # Retrospective modes: free warm-up model, reload backbone
        del warmup_model
        gc.collect()
        if device != "cpu":
            torch.cuda.empty_cache()
        if args.mode == "retro-sparse":
            model, adapter_params = setup_retro_sparse(masks, args, device)
            weight_report_fn = model.sparse_weight_report
        else:  # retro-bottleneck
            model, adapter_params = setup_retro_bottleneck(masks, args, device)
            weight_report_fn = model.adapter_weight_report

    best_val_loss = train_loop(
        model,
        adapter_params,
        train_loader,
        val_loader,
        mask_builder,
        val_legal_indices,
        logger,
        args,
        device,
        use_amp,
        gpu_cfg,
        weight_report_fn,
    )

    logger.close()
    print(f"\nDone. Best val_loss={best_val_loss:.4f}")
    print(f"Checkpoints saved to {out_dir}")


if __name__ == "__main__":
    main()