supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
4.89 kB
"""
Compare Bertint V3 ablation results across Bertose checkpoints.
Loads test_results.json from each ablation run and produces a
comparison table for the paper.
Usage:
python compare_ablation.py --base_dir checkpoints/
python compare_ablation.py --base_dir /work/ratul1/supantha/Bertint/checkpoints/
"""
import json
import argparse
from pathlib import Path
# Ablation chain (order matters for table).
# Each entry is (short_name, display_name):
#   - short_name selects the run directory "ablation_<short_name>" and is the
#     key under which that run's parsed results are stored;
#   - display_name is the row label used in the printed and LaTeX tables.
ABLATION_NAMES = [
    ("seq_only", "1. Seq-only"),
    ("multimodal", "2. +Multimodal"),
    ("topology", "3. +Topology"),
    ("ipa", "4. +IPA"),
    ("contrastive", "5. +Contrastive (Bertose v6)"),
]
# Metrics to display.
# Each entry is (key, label, fmt):
#   - key is looked up in a run's loaded results dict;
#   - label is the column header;
#   - fmt is the format-spec suffix (precision/type) applied to the value.
# The LaTeX export skips the column whose label is "N".
METRICS = [
    ("pearson", "Pearson", ".4f"),
    ("spearman", "Spearman", ".4f"),
    ("r2", "R²", ".4f"),
    ("mse", "MSE", ".6f"),
    ("frac_in_interval", "InInterval", ".3f"),
    ("n_samples", "N", "d"),
]
def load_results(base_dir: Path) -> dict:
    """Load test_results.json from each ablation directory.

    For every entry in ``ABLATION_NAMES`` this looks for
    ``<base_dir>/ablation_<short_name>/test_results.json`` and prints a
    one-line status for it.

    Args:
        base_dir: Directory containing the ``ablation_*`` run subdirectories.

    Returns:
        Mapping of ablation short name to the parsed JSON object. Runs whose
        results file is missing, unreadable, or not a JSON object are omitted
        (a status line is printed instead of raising), so one broken run does
        not prevent comparing the others.
    """
    results = {}
    for short_name, _ in ABLATION_NAMES:
        result_path = base_dir / f"ablation_{short_name}" / "test_results.json"
        if not result_path.exists():
            print(f" ⏳ Not found: {result_path}")
            continue
        try:
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(result_path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError,
            # e.g. a file that is still being written by a running job.
            print(f" ⚠️ Could not read {result_path}: {err}")
            continue
        if not isinstance(data, dict):
            # Downstream code calls .get() on each entry.
            print(f" ⚠️ Unexpected content (not a JSON object): {result_path}")
            continue
        results[short_name] = data
        print(f" ✅ Loaded: {result_path}")
    return results
def _is_number(value) -> bool:
    """Return True for int/float values (bool is deliberately excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _format_metric(value, fmt: str, width: int, signed: bool = False) -> str:
    """Right-align *value* in *width* characters using format suffix *fmt*.

    Non-numeric values (including ``None`` for a missing key) render as
    ``n/a``. With ``signed=True`` the sign is always shown and sits directly
    against the digits.
    """
    if not _is_number(value):
        return f"{'n/a':>{width}}"
    if fmt.endswith("d") and isinstance(value, float):
        if not value.is_integer():
            # NaN, inf or a fractional value: the 'd' type would raise for a
            # float, so fall back to general formatting.
            return f"{value:>{width}g}"
        value = int(value)
    sign = "+" if signed else ""
    return f"{value:>{sign}{width}{fmt}}"


def print_table(results: dict) -> None:
    """Print formatted comparison table.

    One row is printed per entry in ``ABLATION_NAMES`` (in that order) and one
    column per entry in ``METRICS``. Runs absent from *results* are shown as
    "(not available)". If both the ``seq_only`` and ``contrastive`` runs are
    present, a final row shows the per-metric difference between them.

    Args:
        results: Mapping of ablation short name to that run's metrics dict.
    """
    col_width = 10
    header = f"{'Checkpoint':<32}" + "".join(
        f" {label:>{col_width}}" for _, label, _ in METRICS
    )
    heavy_rule = "=" * len(header)
    light_rule = "-" * len(header)

    print()
    print(heavy_rule)
    print("BERTOSE ABLATION — DOWNSTREAM BINDING PREDICTION (Bertint V3)")
    print(heavy_rule)
    print(header)
    print(light_rule)

    # Rows
    for short_name, display_name in ABLATION_NAMES:
        data = results.get(short_name)
        if not isinstance(data, dict):
            print(f"{display_name:<32} {'(not available)':>10}")
            continue
        cells = "".join(
            " " + _format_metric(data.get(key), fmt, col_width)
            for key, _, fmt in METRICS
        )
        print(f"{display_name:<32}{cells}")
    print(light_rule)

    # Delta from baseline
    base = results.get("seq_only")
    best = results.get("contrastive")
    if isinstance(base, dict) and isinstance(best, dict):
        delta_cells = []
        for key, _, fmt in METRICS:
            if key == "n_samples":
                # A difference of sample counts is not meaningful here.
                delta_cells.append(f" {'—':>{col_width}}")
                continue
            hi, lo = best.get(key), base.get(key)
            # Only subtract when both sides are real numbers; a missing
            # metric must not be treated as 0.
            delta = hi - lo if _is_number(hi) and _is_number(lo) else None
            delta_cells.append(
                " " + _format_metric(delta, fmt, col_width, signed=True)
            )
        print(f"\n{'Δ (best - baseline)':<32}" + "".join(delta_cells))
    print()
def save_latex(results: dict, output_path: Path) -> None:
    """Save LaTeX table for paper.

    Writes a booktabs-style ``table`` environment with one row per ablation
    present in *results* (in ``ABLATION_NAMES`` order) and one column per
    metric in ``METRICS`` except the sample-count column labelled "N".

    Args:
        results: Mapping of ablation short name to that run's metrics dict.
        output_path: Destination ``.tex`` file; parent directories are
            created if they do not exist.
    """
    # The N column is omitted from the LaTeX table, so the column spec must be
    # derived from the metrics that are actually emitted.
    shown = [(key, label, fmt) for key, label, fmt in METRICS if label != "N"]
    # Labels that need LaTeX markup instead of raw Unicode characters.
    latex_labels = {"R²": r"$R^2$"}

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Ablation study: Effect of Bertose pretraining stages on downstream glycan-protein binding prediction (Bertint V3, lectin-cold test set).}",
        r"\label{tab:ablation}",
        r"\begin{tabular}{l" + "c" * len(shown) + "}",
        r"\toprule",
    ]

    # Header
    header_cells = ["Checkpoint"]
    header_cells.extend(latex_labels.get(label, label) for _, label, _ in shown)
    lines.append(" & ".join(header_cells) + r" \\")
    lines.append(r"\midrule")

    # Rows
    for short_name, display_name in ABLATION_NAMES:
        if short_name not in results:
            continue
        data = results[short_name]
        cells = [display_name.replace("+", r"\,+\,")]
        for key, _, fmt in shown:
            val = data.get(key)
            is_number = isinstance(val, (int, float)) and not isinstance(val, bool)
            # A missing or non-numeric metric becomes a dash instead of
            # raising inside the format call.
            cells.append(format(val, fmt) if is_number else "--")
        lines.append(" & ".join(cells) + r" \\")

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{table}",
    ])

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"LaTeX table saved to {output_path}")
def main():
    """Command-line entry point.

    Parses ``--base_dir`` and ``--latex``, loads the ablation results found
    under the base directory, prints the comparison table and, when
    ``--latex`` is given, also writes the LaTeX table to that path. Prints a
    hint and returns without printing a table when no results were loaded.
    """
    cli = argparse.ArgumentParser(
        description='Compare Bertint V3 ablation results',
    )
    cli.add_argument(
        '--base_dir',
        type=str,
        default='checkpoints/',
        help='Base directory containing ablation_* subdirs',
    )
    cli.add_argument(
        '--latex',
        type=str,
        default=None,
        help='Optional: save LaTeX table to file',
    )
    opts = cli.parse_args()

    root = Path(opts.base_dir)
    print(f"Loading results from {root}")
    loaded = load_results(root)

    if loaded:
        print_table(loaded)
        if opts.latex:
            save_latex(loaded, Path(opts.latex))
    else:
        print("\n❌ No results found. Have you run the ablation training?")


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()