#!/usr/bin/env python3
"""
MR-JEPA Ablation Launcher

Generates hf_jobs commands for all 12 ablation experiments.
Run hybrid_main FIRST, verify it works, then launch the rest.

Usage:
    python launch_ablations.py --dry-run   # Print commands only
    python launch_ablations.py --launch    # Actually submit jobs

NOTE: job submission is not implemented in this file. The script always
prints one ``python train_mrjepa.py ...`` command per ablation; passing
``--launch`` additionally emits a warning on stderr so that the missing
submission step is not silently mistaken for a successful launch.

Ablation matrix (12 ablations + 1 baseline = 13 total):

| Experiment       | CLI flags                          | Tests                        |
|------------------|------------------------------------|------------------------------|
| hybrid_main      | (default)                          | Full model baseline          |
| no_jepa          | --no_jepa                          | JEPA objective value         |
| no_rollout       | --no_rollout                       | Iterative refinement value   |
| no_gate          | --no_evidence_gate                 | Evidence gating value        |
| K1               | --K 1                              | Optimal rollout depth        |
| K5               | --K 5                              | Optimal rollout depth        |
| K7               | --K 7                              | Optimal rollout depth        |
| mse_loss         | --loss_fn mse                      | Loss function comparison     |
| cosine_loss      | --loss_fn cosine                   | Loss function comparison     |
| no_sigreg        | --no_sigreg                        | Anti-collapse necessity      |
| vicreg_only      | --no_sigreg --use_vicreg           | SIGReg vs VICReg             |
| dinov2_ablation  | --backbone dinov2                  | DINOv3 vs DINOv2             |
| purist           | --purist                           | Pure JEPA contribution       |

Estimated total: ~39 GPU-hours on A100 ($156 at $4/h)
"""

import sys

# Maps each ablation's run name to the argument string passed to
# train_mrjepa.py. The hybrid_main baseline from the table above is not
# listed here: it uses the default flags and is meant to be run first,
# separately, before the ablations.
ABLATIONS = {
    "no_jepa": "--run_name no_jepa --no_jepa --epochs 15",
    "no_rollout": "--run_name no_rollout --no_rollout --epochs 15",
    "no_gate": "--run_name no_gate --no_evidence_gate --epochs 15",
    "K1": "--run_name K1 --K 1 --epochs 15",
    "K5": "--run_name K5 --K 5 --epochs 15",
    "K7": "--run_name K7 --K 7 --epochs 15",
    "mse_loss": "--run_name mse_loss --loss_fn mse --epochs 15",
    "cosine_loss": "--run_name cosine_loss --loss_fn cosine --epochs 15",
    "no_sigreg": "--run_name no_sigreg --no_sigreg --epochs 15",
    "vicreg_only": "--run_name vicreg_only --no_sigreg --use_vicreg --epochs 15",
    "dinov2_ablation": "--run_name dinov2_ablation --backbone dinov2 --epochs 15",
    "purist": "--run_name purist --purist --epochs 15",
}


def command_lines(ablations=None):
    """Return the output lines for the given ablations, in dict order.

    Each experiment contributes three lines: a ``# <name>`` header, the
    ``python train_mrjepa.py <args>`` command, and an empty separator line.

    Args:
        ablations: Mapping of run name to argument string. Defaults to the
            module-level ``ABLATIONS`` table.

    Returns:
        A list of strings, one per output line (without trailing newlines).
    """
    if ablations is None:
        ablations = ABLATIONS
    lines = []
    for name, args in ablations.items():
        lines.extend((f"# {name}", f"python train_mrjepa.py {args}", ""))
    return lines


def main(argv=None):
    """Print every ablation command to stdout.

    Args:
        argv: Command-line arguments excluding the program name. Defaults to
            ``sys.argv[1:]``. Only ``--launch`` is inspected; every other
            argument (including ``--dry-run``) is accepted and ignored, so
            the printed output is the same regardless of the flags given.
    """
    if argv is None:
        argv = sys.argv[1:]
    if "--launch" in argv:
        # The usage text promises submission, but nothing in this file
        # submits jobs; say so loudly instead of silently doing a dry run.
        print(
            "warning: --launch is not implemented; printing commands only",
            file=sys.stderr,
        )
    for line in command_lines():
        print(line)


if __name__ == "__main__":
    main()