""" Compare Bertint V3 ablation results across Bertose checkpoints. Loads test_results.json from each ablation run and produces a comparison table for the paper. Usage: python compare_ablation.py --base_dir checkpoints/ python compare_ablation.py --base_dir /work/ratul1/supantha/Bertint/checkpoints/ """ import json import argparse from pathlib import Path # Ablation chain (order matters for table) ABLATION_NAMES = [ ("seq_only", "1. Seq-only"), ("multimodal", "2. +Multimodal"), ("topology", "3. +Topology"), ("ipa", "4. +IPA"), ("contrastive", "5. +Contrastive (Bertose v6)"), ] # Metrics to display METRICS = [ ("pearson", "Pearson", ".4f"), ("spearman", "Spearman", ".4f"), ("r2", "R²", ".4f"), ("mse", "MSE", ".6f"), ("frac_in_interval", "InInterval", ".3f"), ("n_samples", "N", "d"), ] def load_results(base_dir: Path) -> dict: """Load test_results.json from each ablation directory.""" results = {} for short_name, display_name in ABLATION_NAMES: result_path = base_dir / f"ablation_{short_name}" / "test_results.json" if result_path.exists(): with open(result_path) as f: results[short_name] = json.load(f) print(f" ✅ Loaded: {result_path}") else: print(f" ⏳ Not found: {result_path}") return results def print_table(results: dict) -> None: """Print formatted comparison table.""" # Header header = f"{'Checkpoint':<32}" for _, label, _ in METRICS: header += f" {label:>10}" print() print("=" * len(header)) print("BERTOSE ABLATION — DOWNSTREAM BINDING PREDICTION (Bertint V3)") print("=" * len(header)) print(header) print("-" * len(header)) # Rows for short_name, display_name in ABLATION_NAMES: if short_name not in results: row = f"{display_name:<32} {'(not available)':>10}" print(row) continue data = results[short_name] row = f"{display_name:<32}" for key, _, fmt in METRICS: val = data.get(key, float('nan')) row += f" {val:>{10}{fmt}}" print(row) print("-" * len(header)) # Delta from baseline if "seq_only" in results and "contrastive" in results: base = results["seq_only"] best = results["contrastive"] print(f"\n{'Δ (best - baseline)':<32}", end="") for key, _, fmt in METRICS: if key == "n_samples": print(f" {'—':>10}", end="") else: delta = best.get(key, 0) - base.get(key, 0) sign = "+" if delta >= 0 else "" print(f" {sign}{delta:>{9}{fmt}}", end="") print() print() def save_latex(results: dict, output_path: Path) -> None: """Save LaTeX table for paper.""" lines = [ r"\begin{table}[t]", r"\centering", r"\caption{Ablation study: Effect of Bertose pretraining stages on downstream glycan-protein binding prediction (Bertint V3, lectin-cold test set).}", r"\label{tab:ablation}", r"\begin{tabular}{l" + "c" * len(METRICS) + "}", r"\toprule", ] # Header header = "Checkpoint" for _, label, _ in METRICS: if label == "N": continue header += f" & {label}" header += r" \\" lines.append(header) lines.append(r"\midrule") # Rows for short_name, display_name in ABLATION_NAMES: if short_name not in results: continue data = results[short_name] row = display_name.replace("+", r"\,+\,") for key, label, fmt in METRICS: if label == "N": continue val = data.get(key, float('nan')) row += f" & {val:{fmt}}" row += r" \\" lines.append(row) lines.extend([ r"\bottomrule", r"\end{tabular}", r"\end{table}", ]) with open(output_path, 'w') as f: f.write("\n".join(lines)) print(f"LaTeX table saved to {output_path}") def main(): parser = argparse.ArgumentParser( description='Compare Bertint V3 ablation results' ) parser.add_argument( '--base_dir', type=str, default='checkpoints/', help='Base directory containing ablation_* subdirs' ) parser.add_argument( '--latex', type=str, default=None, help='Optional: save LaTeX table to file' ) args = parser.parse_args() base_dir = Path(args.base_dir) print(f"Loading results from {base_dir}") results = load_results(base_dir) if not results: print("\n❌ No results found. Have you run the ablation training?") return print_table(results) if args.latex: save_latex(results, Path(args.latex)) if __name__ == '__main__': main()