# MR-JEPA / launch_ablations.py
# Author: JorgeAV
# Commit message: add: launch_ablations.py — complete ablation matrix with CLI commands for all 12 experiments
# Commit: 212d421 (verified)
#!/usr/bin/env python3
"""
MR-JEPA Ablation Launcher

Prints the ``python train_mrjepa.py ...`` command line for each of the 12
ablation experiments listed below. Run the hybrid_main baseline FIRST, verify
it works, then launch the rest. This script only prints commands; it does not
submit jobs itself. Submit the printed commands with your job runner
(e.g. hf_jobs).

Usage:
    python launch_ablations.py    # Print one command per ablation

Ablation matrix (12 ablations + 1 baseline = 13 total). Every ablation command
also passes ``--run_name <Experiment>`` and ``--epochs 15``. The hybrid_main
baseline is not generated by this script.

| Experiment       | CLI flags                          | Tests                        |
|------------------|------------------------------------|------------------------------|
| hybrid_main      | (default)                          | Full model baseline          |
| no_jepa          | --no_jepa                          | JEPA objective value         |
| no_rollout       | --no_rollout                       | Iterative refinement value   |
| no_gate          | --no_evidence_gate                 | Evidence gating value        |
| K1               | --K 1                              | Optimal rollout depth        |
| K5               | --K 5                              | Optimal rollout depth        |
| K7               | --K 7                              | Optimal rollout depth        |
| mse_loss         | --loss_fn mse                      | Loss function comparison     |
| cosine_loss      | --loss_fn cosine                   | Loss function comparison     |
| no_sigreg        | --no_sigreg                        | Anti-collapse necessity      |
| vicreg_only      | --no_sigreg --use_vicreg           | SIGReg vs VICReg             |
| dinov2_ablation  | --backbone dinov2                  | DINOv3 vs DINOv2             |
| purist           | --purist                           | Pure JEPA contribution       |

Estimated total: ~39 GPU-hours on A100 ($156 at $4/h)
"""
# CLI flags that distinguish each ablation from the hybrid_main baseline.
# Each experiment's full argument string is assembled as
# "--run_name <key> <flags> --epochs 15", so the run name always matches the
# dictionary key and every ablation trains for the same number of epochs.
# Entry order below is the order in which the commands are printed.
ABLATIONS = {
    experiment: f"--run_name {experiment} {flags} --epochs 15"
    for experiment, flags in (
        ("no_jepa", "--no_jepa"),
        ("no_rollout", "--no_rollout"),
        ("no_gate", "--no_evidence_gate"),
        ("K1", "--K 1"),
        ("K5", "--K 5"),
        ("K7", "--K 7"),
        ("mse_loss", "--loss_fn mse"),
        ("cosine_loss", "--loss_fn cosine"),
        ("no_sigreg", "--no_sigreg"),
        ("vicreg_only", "--no_sigreg --use_vicreg"),
        ("dinov2_ablation", "--backbone dinov2"),
        ("purist", "--purist"),
    )
}
def build_commands(ablations, script="train_mrjepa.py"):
    """Return the shell command line for each ablation.

    Args:
        ablations: Mapping of experiment name -> CLI argument string.
        script: Training script that every command invokes.

    Returns:
        A dict mapping each experiment name to ``"python <script> <args>"``,
        in the same order as *ablations*.
    """
    return {name: f"python {script} {args}" for name, args in ablations.items()}


def main(argv=None):
    """Print one training command per entry in ``ABLATIONS``.

    Each command is printed as a ``# <name>`` header line, the command itself,
    and a blank separator line.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    ``--dry-run`` (also the behaviour with no flags) only prints the commands.
    ``--launch`` exits with a usage error (status 2): this script has no
    job-submission backend, and printing the commands would wrongly suggest
    that jobs had been submitted.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Print train_mrjepa.py commands for the MR-JEPA ablations."
    )
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--dry-run",
        action="store_true",
        help="print commands only (default behaviour)",
    )
    mode.add_argument(
        "--launch",
        action="store_true",
        help="submit jobs (not implemented; exits with an error)",
    )
    opts = parser.parse_args(argv)
    if opts.launch:
        # Fail loudly rather than silently printing as if jobs were submitted.
        parser.error(
            "--launch is not implemented; submit the printed commands manually"
        )

    for name, command in build_commands(ABLATIONS).items():
        print(f"# {name}")
        print(command)
        print()


if __name__ == "__main__":
    main()