File size: 2,807 Bytes
212d421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python3
"""
MR-JEPA Ablation Launcher

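Prints the `train_mrjepa.py` command line for each of the 12 ablation
experiments in ABLATIONS, for submission as hf_jobs. The hybrid_main
baseline is not in ABLATIONS and is not printed: run it FIRST with the
default flags, verify it works, then launch the rest.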

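Usage:
    python launch_ablations.py     # Print one command per ablation

    The --dry-run / --launch flags (print only vs. actually submit jobs) are
    planned but not implemented yet: the script parses no arguments and only
    prints the commands; it never submits jobs itself.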

Ablation matrix (12 ablations + 1 baseline = 13 total):

| Experiment       | CLI flags                          | Tests                        |
|------------------|------------------------------------|------------------------------|
| hybrid_main      | (default)                          | Full model baseline          |
| no_jepa          | --no_jepa                          | JEPA objective value         |
| no_rollout       | --no_rollout                       | Iterative refinement value   |
| no_gate          | --no_evidence_gate                 | Evidence gating value        |
| K1               | --K 1                              | Optimal rollout depth        |
| K5               | --K 5                              | Optimal rollout depth        |
| K7               | --K 7                              | Optimal rollout depth        |
| mse_loss         | --loss_fn mse                      | Loss function comparison     |
| cosine_loss      | --loss_fn cosine                   | Loss function comparison     |
| no_sigreg        | --no_sigreg                        | Anti-collapse necessity      |
| vicreg_only      | --no_sigreg --use_vicreg           | SIGReg vs VICReg             |
| dinov2_ablation  | --backbone dinov2                  | DINOv3 vs DINOv2             |
| purist           | --purist                           | Pure JEPA contribution       |

Estimated total: ~39 GPU-hours on A100 ($156 at $4/h)
"""

# Number of training epochs shared by every ablation run.
_ABLATION_EPOCHS = 15

# Experiment-specific train_mrjepa.py flags for each ablation, keyed by run
# name. Dict insertion order is preserved, so this is also the order in which
# the entries of ABLATIONS are iterated.
_ABLATION_FLAGS = {
    "no_jepa":          "--no_jepa",
    "no_rollout":       "--no_rollout",
    "no_gate":          "--no_evidence_gate",
    "K1":               "--K 1",
    "K5":               "--K 5",
    "K7":               "--K 7",
    "mse_loss":         "--loss_fn mse",
    "cosine_loss":      "--loss_fn cosine",
    "no_sigreg":        "--no_sigreg",
    "vicreg_only":      "--no_sigreg --use_vicreg",
    "dinov2_ablation":  "--backbone dinov2",
    "purist":           "--purist",
}

# Full train_mrjepa.py argument string per ablation, keyed by run name.
# --run_name is derived from the dict key so the two can never drift apart,
# and --epochs comes from a single shared constant instead of 12 copies.
ABLATIONS = {
    name: f"--run_name {name} {flags} --epochs {_ABLATION_EPOCHS}"
    for name, flags in _ABLATION_FLAGS.items()
}

if __name__ == "__main__":
    # For every ablation, emit a "# <run name>" header, the matching
    # train_mrjepa.py invocation, and a blank separator line. print() adds
    # the final newline after the embedded trailing "\n".
    for run_name, cli_args in ABLATIONS.items():
        print(f"# {run_name}\npython train_mrjepa.py {cli_args}\n")