"""
Compare Bertint V3 ablation results across Bertose checkpoints.

Loads test_results.json from each ablation run and produces a
comparison table for the paper.

Usage:
    python compare_ablation.py --base_dir checkpoints/
    python compare_ablation.py --base_dir /work/ratul1/supantha/Bertint/checkpoints/
"""
import argparse
import json
import math
from pathlib import Path
# Ablation chain, listed in the exact order the rows appear in every table.
# Each entry is (run suffix, row label); a run's results are read from
# <base_dir>/ablation_<run suffix>/test_results.json.
ABLATION_NAMES = [
    ("seq_only",    "1. Seq-only"),
    ("multimodal",  "2. +Multimodal"),
    ("topology",    "3. +Topology"),
    ("ipa",         "4. +IPA"),
    ("contrastive", "5. +Contrastive (Bertose v6)"),
]

# Columns shown for each run, as (JSON key, column label, format spec).
# The format spec is applied to the value read from test_results.json.
METRICS = [
    ("pearson",          "Pearson",    ".4f"),
    ("spearman",         "Spearman",   ".4f"),
    ("r2",               "R²",         ".4f"),
    ("mse",              "MSE",        ".6f"),
    ("frac_in_interval", "InInterval", ".3f"),
    ("n_samples",        "N",          "d"),
]
def load_results(base_dir: Path) -> dict:
    """Load ``test_results.json`` from each ablation run directory.

    For every entry in ``ABLATION_NAMES`` this looks for
    ``<base_dir>/ablation_<short_name>/test_results.json`` and prints one
    status line per run.

    Args:
        base_dir: Directory containing the ``ablation_*`` run directories
            (a ``str`` is accepted as well as a ``Path``).

    Returns:
        Mapping from ablation short name to the parsed JSON object. Runs whose
        file is missing, is not valid JSON, or does not hold a JSON object are
        reported and omitted.
    """
    base_dir = Path(base_dir)
    results = {}
    for short_name, _ in ABLATION_NAMES:
        result_path = base_dir / f"ablation_{short_name}" / "test_results.json"
        if not result_path.exists():
            print(f" ⏳ Not found: {result_path}")
            continue
        try:
            with open(result_path, encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            # A run that is still writing (or crashed mid-write) can leave a
            # truncated file; report it instead of aborting the comparison.
            print(f" ⚠️ Unreadable: {result_path} ({err})")
            continue
        if not isinstance(data, dict):
            # Downstream code reads metrics with dict.get(), so reject other shapes here.
            print(f" ⚠️ Unexpected format (not a JSON object): {result_path}")
            continue
        results[short_name] = data
        print(f" ✅ Loaded: {result_path}")
    return results
def _format_cell(val, fmt: str, width: int = 10, signed: bool = False) -> str:
    """Right-align one metric value in a ``width``-character cell.

    Integral floats are accepted for integer (``"d"``) formats. NaN, missing,
    ``None`` or otherwise unformattable values render as ``nan`` instead of
    raising. With ``signed=True`` an explicit ``+``/``-`` is attached to the
    number itself, so positive and negative values keep the same width.
    """
    if isinstance(val, float):
        if math.isnan(val):
            return f"{'nan':>{width}}"
        if fmt.endswith("d") and val.is_integer():
            # JSON writers sometimes emit counts as 123.0; "d" rejects floats.
            val = int(val)
    sign = "+" if signed else ""
    try:
        return f"{val:>{sign}{width}{fmt}}"
    except (TypeError, ValueError):
        return f"{'nan':>{width}}"


def _delta(best: dict, base: dict, key: str) -> float:
    """Return ``best[key] - base[key]``, or NaN if either value is missing/non-numeric."""
    try:
        return float(best[key]) - float(base[key])
    except (KeyError, TypeError, ValueError):
        # Defaulting to 0 would print a fabricated delta; NaN is explicit.
        return float("nan")


def print_table(results: dict) -> None:
    """Print a formatted comparison table of all ablation runs to stdout.

    Rows follow the order of ``ABLATION_NAMES``; runs absent from ``results``
    are shown as ``(not available)``. Each row shows the metrics listed in
    ``METRICS``, with missing or unformattable values printed as ``nan``.
    When both the ``seq_only`` and ``contrastive`` runs are present, a final
    row shows the signed difference (contrastive minus seq_only) per metric.

    Args:
        results: Mapping from ablation short name to its metrics dict, as
            returned by ``load_results``.
    """
    header = f"{'Checkpoint':<32}" + "".join(f" {label:>10}" for _, label, _ in METRICS)
    rule = "-" * len(header)

    print()
    print("=" * len(header))
    print("BERTOSE ABLATION — DOWNSTREAM BINDING PREDICTION (Bertint V3)")
    print("=" * len(header))
    print(header)
    print(rule)

    for short_name, display_name in ABLATION_NAMES:
        if short_name not in results:
            print(f"{display_name:<32} {'(not available)':>10}")
            continue
        data = results[short_name]
        cells = "".join(
            f" {_format_cell(data.get(key, float('nan')), fmt)}" for key, _, fmt in METRICS
        )
        print(f"{display_name:<32}{cells}")
    print(rule)

    # Delta of the full model (contrastive) over the sequence-only baseline.
    if "seq_only" in results and "contrastive" in results:
        base = results["seq_only"]
        best = results["contrastive"]
        cells = []
        for key, _, fmt in METRICS:
            if key == "n_samples":
                # A difference in sample counts is not meaningful.
                cells.append(f" {'—':>10}")
            else:
                cells.append(" " + _format_cell(_delta(best, base, key), fmt, signed=True))
        print(f"\n{'Δ (best - baseline)':<32}" + "".join(cells))
    print()
def _latex_cell(val, fmt: str) -> str:
    """Format one metric for a LaTeX cell; NaN/missing/unformattable values become ``--``."""
    if isinstance(val, float) and math.isnan(val):
        return "--"
    try:
        return f"{val:{fmt}}"
    except (TypeError, ValueError):
        return "--"


def save_latex(results: dict, output_path: Path) -> None:
    """Write the ablation comparison as a booktabs LaTeX table.

    One row is written per available run, in ``ABLATION_NAMES`` order; runs
    missing from ``results`` are skipped. The sample-count column
    (``n_samples``) is omitted from the LaTeX version. Missing or NaN values
    are written as ``--``.

    Args:
        results: Mapping from ablation short name to its metrics dict, as
            returned by ``load_results``.
        output_path: Destination ``.tex`` file; parent directories are
            created if needed. A ``str`` is accepted as well as a ``Path``.
    """
    # Plain-text labels that need LaTeX markup to typeset reliably.
    latex_labels = {"R²": r"$R^2$"}
    # Only these columns are emitted; the column spec must match their count.
    columns = [
        (key, latex_labels.get(label, label), fmt)
        for key, label, fmt in METRICS
        if key != "n_samples"
    ]

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Ablation study: Effect of Bertose pretraining stages on downstream glycan-protein binding prediction (Bertint V3, lectin-cold test set).}",
        r"\label{tab:ablation}",
        # One left-aligned name column plus one centred column per metric.
        r"\begin{tabular}{l" + "c" * len(columns) + "}",
        r"\toprule",
        "Checkpoint" + "".join(f" & {label}" for _, label, _ in columns) + r" \\",
        r"\midrule",
    ]

    for short_name, display_name in ABLATION_NAMES:
        if short_name not in results:
            continue
        data = results[short_name]
        # Thin spaces around "+" so "+Multimodal" etc. typeset cleanly.
        row = display_name.replace("+", r"\,+\,")
        row += "".join(f" & {_latex_cell(data.get(key), fmt)}" for key, _, fmt in columns)
        lines.append(row + r" \\")

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{table}",
    ])

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    print(f"LaTeX table saved to {output_path}")
def main():
    """Command-line entry point: load every ablation run and report it.

    Prints the comparison table to stdout. When ``--latex`` is given, also
    writes a LaTeX version of the table to that path. If no run produced a
    results file, prints a hint and returns without writing anything.
    """
    cli = argparse.ArgumentParser(description='Compare Bertint V3 ablation results')
    cli.add_argument('--base_dir', type=str, default='checkpoints/',
                     help='Base directory containing ablation_* subdirs')
    cli.add_argument('--latex', type=str, default=None,
                     help='Optional: save LaTeX table to file')
    opts = cli.parse_args()

    root = Path(opts.base_dir)
    print(f"Loading results from {root}")
    loaded = load_results(root)

    if loaded:
        print_table(loaded)
        if opts.latex:
            save_latex(loaded, Path(opts.latex))
    else:
        print("\n❌ No results found. Have you run the ablation training?")


if __name__ == '__main__':
    main()
|