JorgeAV commited on
Commit
b198bd2
·
verified ·
1 Parent(s): ad6ce96

fix: ablation.py — wire all disable flags, add missing experiments (dinov2, mse/cosine loss, no_sigreg, vicreg_only), add run()/load_results() methods, generate CLI commands

Browse files
Files changed (1) hide show
  1. mr_jepa/utils/ablation.py +317 -83
mr_jepa/utils/ablation.py CHANGED
@@ -3,22 +3,32 @@ Ablation Study Runner for MR-JEPA.
3
 
4
  Supports systematic ablation experiments to validate the paper's contributions:
5
 
6
- 1. Full MR-JEPA vs. No JEPA (remove JEPA loss, train with task loss only)
7
- 2. Full MR-JEPA vs. No Rollout (use z₀ directly, K=0)
8
- 3. Full MR-JEPA vs. No Evidence Gate (remove gating, always use full evidence)
9
- 4. K=1 vs. K=3 vs. K=5 (rollout depth ablation)
10
- 5. With vs. Without enriched evidence (Phase 3 ablation)
11
- 6. Hybrid vs. Purist branch comparison
 
 
 
 
 
 
 
12
  """
13
 
14
  import copy
15
  import json
 
16
  import logging
17
  from typing import Dict, List, Any, Optional
18
  from dataclasses import dataclass, field
19
  from pathlib import Path
20
 
21
- from ..configs.model_config import MRJEPAConfig, get_hybrid_config, get_purist_config
 
 
22
 
23
  logger = logging.getLogger(__name__)
24
 
@@ -28,155 +38,379 @@ class AblationConfig:
28
  """Configuration for a single ablation experiment."""
29
  name: str
30
  description: str
31
- modifications: Dict[str, Any] = field(default_factory=dict)
32
- # What to change from the base config
 
33
  disable_jepa: bool = False
34
  disable_rollout: bool = False
35
  disable_evidence_gate: bool = False
 
 
36
  override_K: Optional[int] = None
 
 
 
37
 
38
 
39
- # Predefined ablation experiments
 
 
40
  ABLATION_EXPERIMENTS = {
41
- "full_model": AblationConfig(
42
- name="full_model",
43
- description="Complete MR-JEPA (baseline)",
 
 
44
  ),
 
 
45
  "no_jepa": AblationConfig(
46
  name="no_jepa",
47
- description="Without JEPA objective (task loss only)",
 
48
  disable_jepa=True,
49
  ),
50
  "no_rollout": AblationConfig(
51
  name="no_rollout",
52
- description="Without latent rollout (z₀ only, K=0)",
 
53
  disable_rollout=True,
 
 
54
  ),
55
- "no_evidence_gate": AblationConfig(
56
- name="no_evidence_gate",
57
- description="Without evidence gating",
 
58
  disable_evidence_gate=True,
59
  ),
 
 
60
  "K1": AblationConfig(
61
  name="K1",
62
- description="Rollout depth K=1",
 
63
  override_K=1,
64
  ),
65
  "K3": AblationConfig(
66
  name="K3",
67
- description="Rollout depth K=3 (default)",
 
68
  override_K=3,
69
  ),
70
  "K5": AblationConfig(
71
  name="K5",
72
- description="Rollout depth K=5",
 
73
  override_K=5,
74
  ),
75
  "K7": AblationConfig(
76
  name="K7",
77
- description="Rollout depth K=7 (deep rollout)",
 
78
  override_K=7,
79
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  }
81
 
82
 
83
  class AblationRunner:
84
  """
85
  Systematically run ablation experiments.
86
-
 
 
 
 
87
  Usage:
88
- runner = AblationRunner(base_config, experiments=['full_model', 'no_jepa', 'no_rollout'])
89
- results = runner.run(train_data, eval_data)
90
- runner.report()
 
 
 
 
 
 
 
 
 
 
91
  """
92
-
93
  def __init__(
94
  self,
95
- base_config: Optional[MRJEPAConfig] = None,
96
  experiments: Optional[List[str]] = None,
97
- output_dir: str = "./ablations",
 
 
98
  ):
99
- self.base_config = base_config or get_hybrid_config()
100
- self.experiments = experiments or list(ABLATION_EXPERIMENTS.keys())
 
 
 
 
101
  self.output_dir = Path(output_dir)
102
  self.output_dir.mkdir(parents=True, exist_ok=True)
103
- self.results = {}
104
-
 
 
105
  def _apply_ablation(self, config: MRJEPAConfig, ablation: AblationConfig) -> MRJEPAConfig:
106
  """Apply ablation modifications to a config."""
107
  modified = copy.deepcopy(config)
108
-
 
 
 
 
 
 
109
  if ablation.override_K is not None:
110
  modified.rollout.K = ablation.override_K
111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return modified
113
-
114
  def generate_configs(self) -> Dict[str, MRJEPAConfig]:
115
- """Generate configs for all ablation experiments."""
 
116
  configs = {}
117
  for exp_name in self.experiments:
118
  if exp_name not in ABLATION_EXPERIMENTS:
119
- logger.warning(f"Unknown ablation: {exp_name}")
120
  continue
121
-
122
  ablation = ABLATION_EXPERIMENTS[exp_name]
123
- config = self._apply_ablation(self.base_config, ablation)
124
- configs[exp_name] = config
125
-
126
  return configs
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def report(self) -> str:
129
  """Generate a formatted ablation report."""
130
  if not self.results:
131
- return "No results yet."
132
-
133
  lines = [
134
- "=" * 80,
135
  "MR-JEPA Ablation Study Results",
136
- "=" * 80,
137
  "",
138
  ]
139
-
140
- # Header
141
- benchmarks = set()
142
  for exp_results in self.results.values():
143
- benchmarks.update(exp_results.keys())
144
- benchmarks = sorted(benchmarks)
145
-
146
- header = f"{'Experiment':<25}"
147
- for b in benchmarks:
148
- header += f" | {b:<12}"
 
 
 
 
 
149
  lines.append(header)
150
  lines.append("-" * len(header))
151
-
152
- # Results rows
153
- for exp_name, exp_results in self.results.items():
 
 
 
154
  ablation = ABLATION_EXPERIMENTS.get(exp_name)
155
- row = f"{exp_name:<25}"
156
- for b in benchmarks:
157
- if b in exp_results:
158
- val = exp_results[b].get('accuracy',
159
- exp_results[b].get('anls',
160
- exp_results[b].get('vqa_accuracy',
161
- exp_results[b].get('relaxed_accuracy', 0))))
162
- row += f" | {val:>10.1f}%"
 
 
 
163
  else:
164
  row += f" | {'N/A':>10}"
165
  lines.append(row)
166
-
 
 
167
  lines.append("")
168
- lines.append("Key findings:")
169
-
170
- # Auto-detect key findings
171
- if 'full_model' in self.results and 'no_jepa' in self.results:
172
- lines.append("- JEPA vs No-JEPA: Compare 'full_model' and 'no_jepa' rows")
173
- if 'full_model' in self.results and 'no_rollout' in self.results:
174
- lines.append("- Rollout vs No-Rollout: Compare 'full_model' and 'no_rollout' rows")
175
-
176
- report = "\n".join(lines)
177
-
178
- # Save to file
179
- with open(self.output_dir / "ablation_report.txt", "w") as f:
180
- f.write(report)
181
-
182
- return report
 
 
 
 
 
 
 
 
 
3
 
4
  Supports systematic ablation experiments to validate the paper's contributions:
5
 
6
+ 1. hybrid_main — Full MR-JEPA baseline (DINOv3-L, K=3, SmoothL1+VICReg)
7
+ 2. no_jepa — Remove JEPA loss, train with task loss only
8
+ 3. no_rollout — Use z₀ directly (K=0), keep task loss only
9
+ 4. no_gate — Remove evidence gating, always use full evidence
10
+ 5. K1 / K5 / K7 — Rollout depth ablation
11
+ 6. dinov2_ablation — DINOv2-L/14 backbone instead of DINOv3-L/16
12
+ 7. purist — DINOv3-B, K=5, Cosine+SIGReg, no enriched evidence
13
+ 8. mse_loss — MSE (L2) JEPA loss instead of SmoothL1
14
+ 9. cosine_loss — Cosine similarity JEPA loss
15
+ 10. no_sigreg — Disable SIGReg anti-collapse regularization
16
+ 11. vicreg_only — VICReg regularization without SIGReg
17
+
18
+ Each AblationConfig maps 1:1 to CLI flags in train_mrjepa.py.
19
  """
20
 
21
  import copy
22
  import json
23
+ import subprocess
24
  import logging
25
  from typing import Dict, List, Any, Optional
26
  from dataclasses import dataclass, field
27
  from pathlib import Path
28
 
29
+ from ..configs.model_config import (
30
+ MRJEPAConfig, get_hybrid_config, get_purist_config, get_dinov2_ablation_config,
31
+ )
32
 
33
  logger = logging.getLogger(__name__)
34
 
 
38
  """Configuration for a single ablation experiment."""
39
  name: str
40
  description: str
41
+ # CLI flags that map to train_mrjepa.py arguments
42
+ cli_flags: Dict[str, Any] = field(default_factory=dict)
43
+ # Config modifications for the library-based runner
44
  disable_jepa: bool = False
45
  disable_rollout: bool = False
46
  disable_evidence_gate: bool = False
47
+ disable_sigreg: bool = False
48
+ enable_vicreg: bool = False
49
  override_K: Optional[int] = None
50
+ override_loss_fn: Optional[str] = None
51
+ override_backbone: Optional[str] = None
52
+ use_purist: bool = False
53
 
54
 
55
+ # ──────────────────────────────────────────────────────────
56
+ # Complete ablation experiment registry
57
+ # ──────────────────────────────────────────────────────────
58
  ABLATION_EXPERIMENTS = {
59
+ # ── Baseline ──
60
+ "hybrid_main": AblationConfig(
61
+ name="hybrid_main",
62
+ description="Complete MR-JEPA (DINOv3-L, K=3, SmoothL1+VICReg)",
63
+ cli_flags={"--run_name": "hybrid_main"},
64
  ),
65
+
66
+ # ── Core contribution ablations ──
67
  "no_jepa": AblationConfig(
68
  name="no_jepa",
69
+ description="Without JEPA objective task loss only. Tests whether JEPA trajectory supervision adds value.",
70
+ cli_flags={"--run_name": "no_jepa", "--no_jepa": True},
71
  disable_jepa=True,
72
  ),
73
  "no_rollout": AblationConfig(
74
  name="no_rollout",
75
+ description="Without latent rollout z₀ directly to answer head (K=0). Tests whether iterative refinement adds value.",
76
+ cli_flags={"--run_name": "no_rollout", "--no_rollout": True},
77
  disable_rollout=True,
78
+ # NOTE: no_rollout also disables JEPA (can't supervise a trajectory that doesn't exist)
79
+ disable_jepa=True,
80
  ),
81
+ "no_gate": AblationConfig(
82
+ name="no_gate",
83
+ description="Without evidence gating — full evidence at every step. Tests whether adaptive evidence flow matters.",
84
+ cli_flags={"--run_name": "no_gate", "--no_evidence_gate": True},
85
  disable_evidence_gate=True,
86
  ),
87
+
88
+ # ── Rollout depth ablations ──
89
  "K1": AblationConfig(
90
  name="K1",
91
+ description="Rollout depth K=1 (shallow reasoning)",
92
+ cli_flags={"--run_name": "K1", "--K": 1},
93
  override_K=1,
94
  ),
95
  "K3": AblationConfig(
96
  name="K3",
97
+ description="Rollout depth K=3 (default, same as hybrid_main)",
98
+ cli_flags={"--run_name": "K3", "--K": 3},
99
  override_K=3,
100
  ),
101
  "K5": AblationConfig(
102
  name="K5",
103
+ description="Rollout depth K=5 (deeper reasoning)",
104
+ cli_flags={"--run_name": "K5", "--K": 5},
105
  override_K=5,
106
  ),
107
  "K7": AblationConfig(
108
  name="K7",
109
+ description="Rollout depth K=7 (deep rollout — diminishing returns expected)",
110
+ cli_flags={"--run_name": "K7", "--K": 7},
111
  override_K=7,
112
  ),
113
+
114
+ # ── Backbone ablation ──
115
+ "dinov2_ablation": AblationConfig(
116
+ name="dinov2_ablation",
117
+ description="DINOv2-L/14 backbone instead of DINOv3-L/16. Isolates DINOv3 contribution.",
118
+ cli_flags={"--run_name": "dinov2_ablation", "--backbone": "dinov2"},
119
+ override_backbone="dinov2",
120
+ ),
121
+
122
+ # ── Loss function ablations ──
123
+ "mse_loss": AblationConfig(
124
+ name="mse_loss",
125
+ description="MSE (L2) JEPA loss instead of SmoothL1. Original I-JEPA loss.",
126
+ cli_flags={"--run_name": "mse_loss", "--loss_fn": "mse"},
127
+ override_loss_fn="mse",
128
+ ),
129
+ "cosine_loss": AblationConfig(
130
+ name="cosine_loss",
131
+ description="Cosine similarity JEPA loss. Used in purist branch.",
132
+ cli_flags={"--run_name": "cosine_loss", "--loss_fn": "cosine"},
133
+ override_loss_fn="cosine",
134
+ ),
135
+
136
+ # ── Regularization ablations ──
137
+ "no_sigreg": AblationConfig(
138
+ name="no_sigreg",
139
+ description="Disable SIGReg anti-collapse. Expect training instability / collapse.",
140
+ cli_flags={"--run_name": "no_sigreg", "--no_sigreg": True},
141
+ disable_sigreg=True,
142
+ ),
143
+ "vicreg_only": AblationConfig(
144
+ name="vicreg_only",
145
+ description="VICReg regularization only (no SIGReg). Alternative anti-collapse.",
146
+ cli_flags={"--run_name": "vicreg_only", "--no_sigreg": True, "--use_vicreg": True},
147
+ disable_sigreg=True,
148
+ enable_vicreg=True,
149
+ ),
150
+
151
+ # ── Branch comparison ──
152
+ "purist": AblationConfig(
153
+ name="purist",
154
+ description="Purist branch: DINOv3-B, K=5, Cosine+SIGReg, no enriched evidence. Isolates JEPA reasoning from perception quality.",
155
+ cli_flags={"--run_name": "purist", "--purist": True},
156
+ use_purist=True,
157
+ ),
158
  }
159
 
160
 
161
  class AblationRunner:
162
  """
163
  Systematically run ablation experiments.
164
+
165
+ Two modes:
166
+ 1. CLI mode: generates shell commands for train_mrjepa.py (for HF Jobs)
167
+ 2. Config mode: generates MRJEPAConfig objects (for library-based runner)
168
+
169
  Usage:
170
+ runner = AblationRunner(experiments=['hybrid_main', 'no_jepa', 'no_rollout'])
171
+
172
+ # Mode 1: Generate CLI commands
173
+ commands = runner.generate_commands()
174
+ for name, cmd in commands.items():
175
+ print(f"{name}: {cmd}")
176
+
177
+ # Mode 2: Generate configs for programmatic use
178
+ configs = runner.generate_configs()
179
+
180
+ # After running, load results and report
181
+ runner.load_results("./outputs/mrjepa")
182
+ print(runner.report())
183
  """
184
+
185
  def __init__(
186
  self,
 
187
  experiments: Optional[List[str]] = None,
188
+ output_dir: str = "./outputs/mrjepa",
189
+ script_path: str = "train_mrjepa.py",
190
+ common_flags: Optional[Dict[str, Any]] = None,
191
  ):
192
+ self.experiments = experiments or [
193
+ "hybrid_main", "no_jepa", "no_rollout", "no_gate",
194
+ "K1", "K5", "K7",
195
+ "dinov2_ablation", "mse_loss", "cosine_loss",
196
+ "no_sigreg", "purist",
197
+ ]
198
  self.output_dir = Path(output_dir)
199
  self.output_dir.mkdir(parents=True, exist_ok=True)
200
+ self.script_path = script_path
201
+ self.common_flags = common_flags or {}
202
+ self.results: Dict[str, Dict[str, Any]] = {}
203
+
204
  def _apply_ablation(self, config: MRJEPAConfig, ablation: AblationConfig) -> MRJEPAConfig:
205
  """Apply ablation modifications to a config."""
206
  modified = copy.deepcopy(config)
207
+
208
+ if ablation.use_purist:
209
+ return get_purist_config()
210
+
211
+ if ablation.override_backbone == "dinov2":
212
+ return get_dinov2_ablation_config()
213
+
214
  if ablation.override_K is not None:
215
  modified.rollout.K = ablation.override_K
216
+
217
+ if ablation.disable_jepa:
218
+ modified.jepa.use_jepa = False
219
+
220
+ if ablation.disable_rollout:
221
+ modified.rollout.K = 0
222
+ modified.jepa.use_jepa = False # No trajectory to supervise
223
+
224
+ if ablation.disable_evidence_gate:
225
+ modified.rollout.use_evidence_gate = False
226
+ modified.rollout.gate_type = "none"
227
+
228
+ if ablation.disable_sigreg:
229
+ modified.jepa.use_sigreg = False
230
+ modified.jepa.sigreg_weight = 0.0
231
+
232
+ if ablation.enable_vicreg:
233
+ modified.jepa.use_vicreg = True
234
+
235
+ if ablation.override_loss_fn is not None:
236
+ modified.jepa.jepa_loss_fn = ablation.override_loss_fn
237
+
238
  return modified
239
+
240
  def generate_configs(self) -> Dict[str, MRJEPAConfig]:
241
+ """Generate MRJEPAConfig objects for all ablation experiments."""
242
+ base_config = get_hybrid_config()
243
  configs = {}
244
  for exp_name in self.experiments:
245
  if exp_name not in ABLATION_EXPERIMENTS:
246
+ logger.warning(f"Unknown ablation: {exp_name}, skipping")
247
  continue
 
248
  ablation = ABLATION_EXPERIMENTS[exp_name]
249
+ configs[exp_name] = self._apply_ablation(base_config, ablation)
 
 
250
  return configs
251
+
252
+ def generate_commands(self) -> Dict[str, str]:
253
+ """Generate CLI commands for train_mrjepa.py for each ablation."""
254
+ commands = {}
255
+ for exp_name in self.experiments:
256
+ if exp_name not in ABLATION_EXPERIMENTS:
257
+ logger.warning(f"Unknown ablation: {exp_name}, skipping")
258
+ continue
259
+
260
+ ablation = ABLATION_EXPERIMENTS[exp_name]
261
+ parts = ["python", self.script_path]
262
+
263
+ # Merge common flags + experiment-specific flags
264
+ all_flags = {**self.common_flags, **ablation.cli_flags}
265
+
266
+ for flag, value in all_flags.items():
267
+ if isinstance(value, bool):
268
+ if value:
269
+ parts.append(flag)
270
+ else:
271
+ parts.append(flag)
272
+ parts.append(str(value))
273
+
274
+ commands[exp_name] = " ".join(parts)
275
+
276
+ return commands
277
+
278
+ def run(
279
+ self,
280
+ mode: str = "cli",
281
+ dry_run: bool = False,
282
+ ) -> Dict[str, Any]:
283
+ """
284
+ Run all ablation experiments.
285
+
286
+ Args:
287
+ mode: "cli" to run via subprocess, "config" for programmatic (not yet implemented)
288
+ dry_run: If True, print commands but don't execute
289
+
290
+ Returns:
291
+ Dict mapping experiment name to run status/result
292
+ """
293
+ if mode == "cli":
294
+ commands = self.generate_commands()
295
+ results = {}
296
+ for exp_name, cmd in commands.items():
297
+ logger.info(f"{'[DRY RUN] ' if dry_run else ''}Running ablation: {exp_name}")
298
+ logger.info(f" Command: {cmd}")
299
+
300
+ if dry_run:
301
+ results[exp_name] = {"status": "dry_run", "command": cmd}
302
+ continue
303
+
304
+ try:
305
+ proc = subprocess.run(
306
+ cmd, shell=True, capture_output=True, text=True, timeout=7200,
307
+ )
308
+ results[exp_name] = {
309
+ "status": "success" if proc.returncode == 0 else "failed",
310
+ "returncode": proc.returncode,
311
+ "stdout_tail": proc.stdout[-2000:] if proc.stdout else "",
312
+ "stderr_tail": proc.stderr[-2000:] if proc.stderr else "",
313
+ }
314
+ if proc.returncode != 0:
315
+ logger.error(f" FAILED (rc={proc.returncode}): {proc.stderr[-500:]}")
316
+ else:
317
+ logger.info(f" SUCCESS")
318
+ except subprocess.TimeoutExpired:
319
+ results[exp_name] = {"status": "timeout"}
320
+ logger.error(f" TIMEOUT")
321
+ except Exception as e:
322
+ results[exp_name] = {"status": "error", "error": str(e)}
323
+ logger.error(f" ERROR: {e}")
324
+
325
+ return results
326
+ else:
327
+ raise NotImplementedError(f"Mode '{mode}' not implemented. Use 'cli'.")
328
+
329
+ def load_results(self, results_dir: Optional[str] = None):
330
+ """Load results JSON files from a directory."""
331
+ rdir = Path(results_dir) if results_dir else self.output_dir
332
+ for exp_name in self.experiments:
333
+ result_file = rdir / f"results_{exp_name}.json"
334
+ if result_file.exists():
335
+ with open(result_file) as f:
336
+ self.results[exp_name] = json.load(f)
337
+ logger.info(f"Loaded results for {exp_name}")
338
+ else:
339
+ logger.warning(f"No results file for {exp_name} at {result_file}")
340
+
341
  def report(self) -> str:
342
  """Generate a formatted ablation report."""
343
  if not self.results:
344
+ return "No results loaded. Call load_results() first."
345
+
346
  lines = [
347
+ "=" * 90,
348
  "MR-JEPA Ablation Study Results",
349
+ "=" * 90,
350
  "",
351
  ]
352
+
353
+ # Collect all metric keys across experiments
354
+ metric_keys = set()
355
  for exp_results in self.results.values():
356
+ metric_keys.update(k for k in exp_results.keys() if k.startswith("best_") or k.endswith("_accuracy"))
357
+ metric_keys = sorted(metric_keys)
358
+
359
+ if not metric_keys:
360
+ metric_keys = ["best_eval_accuracy"]
361
+
362
+ # Header
363
+ header = f"{'Experiment':<22} | {'K':>2} | {'JEPA':>4} | {'Gate':>4} | {'Loss':>9}"
364
+ for mk in metric_keys:
365
+ short = mk.replace("best_eval_", "").replace("best_", "").replace("_accuracy", "_acc")[:12]
366
+ header += f" | {short:>10}"
367
  lines.append(header)
368
  lines.append("-" * len(header))
369
+
370
+ # Rows
371
+ for exp_name in self.experiments:
372
+ if exp_name not in self.results:
373
+ continue
374
+ r = self.results[exp_name]
375
  ablation = ABLATION_EXPERIMENTS.get(exp_name)
376
+
377
+ row = f"{exp_name:<22}"
378
+ row += f" | {r.get('K', '?'):>2}"
379
+ row += f" | {'Y' if r.get('use_jepa', True) else 'N':>4}"
380
+ row += f" | {'Y' if r.get('use_evidence_gate', True) else 'N':>4}"
381
+ row += f" | {r.get('loss_fn', 'smooth_l1'):>9}"
382
+
383
+ for mk in metric_keys:
384
+ val = r.get(mk)
385
+ if val is not None:
386
+ row += f" | {val:>9.1f}%"
387
  else:
388
  row += f" | {'N/A':>10}"
389
  lines.append(row)
390
+
391
+ lines.append("")
392
+ lines.append("=" * 90)
393
  lines.append("")
394
+
395
+ # Auto-generate key findings
396
+ lines.append("Key comparisons:")
397
+ if "hybrid_main" in self.results:
398
+ base_acc = self.results["hybrid_main"].get("best_eval_accuracy", 0)
399
+ for exp_name in ["no_jepa", "no_rollout", "no_gate"]:
400
+ if exp_name in self.results:
401
+ exp_acc = self.results[exp_name].get("best_eval_accuracy", 0)
402
+ delta = exp_acc - base_acc
403
+ lines.append(
404
+ f" {exp_name:>15} vs hybrid_main: {delta:+.1f}% "
405
+ f"({'JEPA helps' if delta < 0 else 'no benefit'})"
406
+ )
407
+
408
+ report_text = "\n".join(lines)
409
+
410
+ # Save
411
+ report_path = self.output_dir / "ablation_report.txt"
412
+ with open(report_path, "w") as f:
413
+ f.write(report_text)
414
+ logger.info(f"Ablation report saved to {report_path}")
415
+
416
+ return report_text