"""
Compare Bertint V3 ablation results across Bertose checkpoints.

Loads test_results.json from each ablation run
(<base_dir>/ablation_<name>/test_results.json), prints a comparison
table, and optionally writes a LaTeX version of it for the paper.

Usage:
    python compare_ablation.py --base_dir checkpoints/
    python compare_ablation.py --base_dir /work/ratul1/supantha/Bertint/checkpoints/
    python compare_ablation.py --base_dir checkpoints/ --latex ablation_table.tex
"""
|
|
| import json |
| import argparse |
| from pathlib import Path |
|
|
|
|
| |
# Ablation runs in the order they are reported.
# Each entry is (short_name, display_name):
#   short_name   -- suffix of the run directory, i.e. "ablation_<short_name>"
#   display_name -- row label used in the printed and LaTeX tables
ABLATION_NAMES = [
    ("seq_only",    "1. Seq-only"),
    ("multimodal",  "2. +Multimodal"),
    ("topology",    "3. +Topology"),
    ("ipa",         "4. +IPA"),
    ("contrastive", "5. +Contrastive (Bertose v6)"),
]
|
|
| |
# Metrics shown as table columns, in display order.
# Each entry is (json_key, column_label, format_spec):
#   json_key     -- key looked up in a run's test_results.json
#   column_label -- column heading
#   format_spec  -- format-spec suffix applied to the value (e.g. ".4f", "d")
METRICS = [
    ("pearson",          "Pearson",    ".4f"),
    ("spearman",         "Spearman",   ".4f"),
    ("r2",               "R²",         ".4f"),
    ("mse",              "MSE",        ".6f"),
    ("frac_in_interval", "InInterval", ".3f"),
    ("n_samples",        "N",          "d"),
]
|
|
|
|
def load_results(base_dir: Path, *, ablation_names=None) -> dict:
    """Load test_results.json from each ablation directory.

    For every ``(short_name, display_name)`` pair the file
    ``<base_dir>/ablation_<short_name>/test_results.json`` is read. Runs
    whose file is missing, unreadable, not valid JSON, or not a JSON
    object are reported on stdout and skipped, so one unfinished or
    corrupt run does not abort the whole comparison.

    Args:
        base_dir: Directory containing the ``ablation_*`` sub-directories.
        ablation_names: Optional iterable of ``(short_name, display_name)``
            pairs. Defaults to the module-level ``ABLATION_NAMES``.

    Returns:
        Dict mapping ``short_name`` to the decoded JSON object, containing
        only the runs that were loaded successfully.
    """
    if ablation_names is None:
        ablation_names = ABLATION_NAMES

    results = {}
    for short_name, _display_name in ablation_names:
        result_path = Path(base_dir) / f"ablation_{short_name}" / "test_results.json"
        # EAFP: open directly instead of exists()+open(), which can race
        # with a training job that is still creating or replacing the file.
        try:
            with open(result_path, encoding="utf-8") as f:
                payload = json.load(f)
        except FileNotFoundError:
            print(f" ⏳ Not found: {result_path}")
            continue
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f" ⚠️ Could not read {result_path}: {exc}")
            continue

        if not isinstance(payload, dict):
            # Downstream code calls .get() on each result.
            print(f" ⚠️ Ignoring {result_path}: expected a JSON object, "
                  f"got {type(payload).__name__}")
            continue

        results[short_name] = payload
        print(f" ✅ Loaded: {result_path}")
    return results
|
|
|
|
def _is_number(value) -> bool:
    """Return True for real int/float values (bool is deliberately excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _format_cell(value, fmt: str, width: int = 10, signed: bool = False) -> str:
    """Right-align one metric value in *width* columns; "—" if not numeric."""
    if not _is_number(value):
        return f"{'—':>{width}}"
    if fmt.endswith("d") and isinstance(value, float):
        # The "d" format code rejects floats. JSON counts may arrive as
        # 1000.0, and a missing key is represented by NaN.
        if value.is_integer():
            value = int(value)
        else:
            fmt = "g"
    sign = "+" if signed else ""
    return f"{value:>{sign}{width}{fmt}}"


def print_table(results: dict, *, ablations=None, metrics=None) -> None:
    """Print formatted comparison table.

    One row is printed per ablation, in the given order; ablations without
    an entry in *results* are shown as "(not available)". A missing metric
    prints as ``nan`` and a non-numeric one (e.g. JSON ``null``) as "—".
    When both the first and the last listed ablation are present, a final
    row shows the per-metric difference (last minus first).

    Args:
        results: Mapping of ablation short name to its metrics dict.
        ablations: Optional sequence of ``(short_name, display_name)``
            pairs. Defaults to the module-level ``ABLATION_NAMES``.
        metrics: Optional sequence of ``(key, label, format_spec)``
            triples. Defaults to the module-level ``METRICS``.
    """
    if ablations is None:
        ablations = ABLATION_NAMES
    if metrics is None:
        metrics = METRICS
    ablations = list(ablations)
    metrics = list(metrics)

    header = f"{'Checkpoint':<32}" + "".join(
        f" {label:>10}" for _, label, _ in metrics
    )
    heavy_rule = "=" * len(header)
    light_rule = "-" * len(header)

    print()
    print(heavy_rule)
    print("BERTOSE ABLATION — DOWNSTREAM BINDING PREDICTION (Bertint V3)")
    print(heavy_rule)
    print(header)
    print(light_rule)

    for short_name, display_name in ablations:
        if short_name not in results:
            print(f"{display_name:<32} {'(not available)':>10}")
            continue
        data = results[short_name]
        cells = "".join(
            f" {_format_cell(data.get(key, float('nan')), fmt)}"
            for key, _, fmt in metrics
        )
        print(f"{display_name:<32}{cells}")

    print(light_rule)

    # Delta row: last listed ablation minus the first one. With the default
    # ABLATION_NAMES this is "contrastive" minus "seq_only".
    if len(ablations) >= 2:
        first_name, last_name = ablations[0][0], ablations[-1][0]
        if first_name in results and last_name in results:
            base, best = results[first_name], results[last_name]
            row = f"{'Δ (best - baseline)':<32}"
            for key, _, fmt in metrics:
                hi, lo = best.get(key), base.get(key)
                if key == "n_samples" or not (_is_number(hi) and _is_number(lo)):
                    # No meaningful difference: sample counts, or a metric
                    # missing from one of the two runs.
                    row += f" {'—':>10}"
                else:
                    # The "+" sign flag keeps the sign adjacent to the digits
                    # and every cell exactly 10 characters wide.
                    row += f" {_format_cell(hi - lo, fmt, signed=True)}"
            print("\n" + row)

    print()
|
|
|
|
# Characters that must not reach LaTeX verbatim when they appear in labels.
_LATEX_ESCAPES = str.maketrans({
    "&": r"\&",
    "%": r"\%",
    "#": r"\#",
    "_": r"\_",
    "²": "$^2$",  # e.g. the "R²" column label
})


def _latex_text(text: str) -> str:
    """Escape LaTeX-special characters in a plain-text label."""
    return str(text).translate(_LATEX_ESCAPES)


def _latex_cell(value, fmt: str) -> str:
    """Format one metric value for a table cell; "--" when missing or NaN."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return "--"
    if value != value:  # NaN is the only value not equal to itself
        return "--"
    if fmt.endswith("d") and isinstance(value, float):
        # The "d" format code rejects floats, even integral ones.
        if value.is_integer():
            value = int(value)
        else:
            fmt = "g"
    return format(value, fmt)


def save_latex(results: dict, output_path: Path, *, ablations=None, metrics=None) -> None:
    """Save LaTeX table for paper.

    Writes a booktabs-style ``table`` environment to *output_path* (UTF-8).
    The metric labelled "N" is omitted from the table. Ablations without
    an entry in *results* are left out, and metrics that are missing or
    not numeric are written as "--".

    Args:
        results: Mapping of ablation short name to its metrics dict.
        output_path: Destination ``.tex`` file; overwritten if it exists.
        ablations: Optional sequence of ``(short_name, display_name)``
            pairs. Defaults to the module-level ``ABLATION_NAMES``.
        metrics: Optional sequence of ``(key, label, format_spec)``
            triples. Defaults to the module-level ``METRICS``.
    """
    if ablations is None:
        ablations = ABLATION_NAMES
    if metrics is None:
        metrics = METRICS

    # Only these columns are emitted, so the column spec must be derived
    # from this list rather than from the full metric list.
    shown = [(key, label, fmt) for key, label, fmt in metrics if label != "N"]

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Ablation study: Effect of Bertose pretraining stages on downstream glycan-protein binding prediction (Bertint V3, lectin-cold test set).}",
        r"\label{tab:ablation}",
        r"\begin{tabular}{l" + "c" * len(shown) + "}",
        r"\toprule",
    ]

    header_cells = ["Checkpoint"] + [_latex_text(label) for _, label, _ in shown]
    lines.append(" & ".join(header_cells) + r" \\")
    lines.append(r"\midrule")

    for short_name, display_name in ablations:
        if short_name not in results:
            continue
        data = results[short_name]
        # Escape first, then add thin spaces around "+", so the inserted
        # backslashes are not themselves touched by the escaping step.
        row_cells = [_latex_text(display_name).replace("+", r"\,+\,")]
        row_cells.extend(_latex_cell(data.get(key), fmt) for key, _, fmt in shown)
        lines.append(" & ".join(row_cells) + r" \\")

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{table}",
    ])

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"LaTeX table saved to {output_path}")
|
|
|
|
def main():
    """Command-line entry point.

    Parses ``--base_dir`` and ``--latex``, loads the ablation results,
    prints the comparison table, and writes the LaTeX table when a
    ``--latex`` path was given. Returns early (after printing a hint) if
    no results could be loaded.
    """
    cli = argparse.ArgumentParser(description='Compare Bertint V3 ablation results')
    cli.add_argument(
        '--base_dir',
        type=str,
        default='checkpoints/',
        help='Base directory containing ablation_* subdirs',
    )
    cli.add_argument(
        '--latex',
        type=str,
        default=None,
        help='Optional: save LaTeX table to file',
    )
    options = cli.parse_args()

    root = Path(options.base_dir)
    print(f"Loading results from {root}")

    loaded = load_results(root)
    if not loaded:
        # Nothing to tabulate; skip both table outputs.
        print("\n❌ No results found. Have you run the ablation training?")
        return

    print_table(loaded)

    latex_target = options.latex
    if latex_target:
        save_latex(loaded, Path(latex_target))


if __name__ == '__main__':
    main()
|
|