# mgnify-evo2-probes / code/probes/refit_thresholds.py
# JG1310 — "Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs"
# (commit eb69de4, verified)
"""Re-render the reads and transfer score-distribution plots using the
v1 AMR probe's TEST-tuned thresholds (from amr_binary_v1) — i.e., the
deployed thresholds — instead of newly-tuned thresholds on each plot's data.
Local-only: reads each plot's cached scores.jsonl, recomputes AUC and
precision/recall/F1 at the fixed thresholds, re-renders the PNG, and
overwrites the plot's summary.json.
Usage:
python probes/refit_thresholds.py
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
# Thresholds from amr_binary_v1_score_distributions.summary.json (test set)
# These are the v1 probe's deployed decision boundaries, tuned on the TEST
# split of AMR-vs-MISC. replot() applies them unchanged to every plot's data
# instead of re-tuning: AMR_MAX_THR to the max-pooled logits and AMR_MEAN_THR
# to the mean-pooled logits. An example counts as positive only when its logit
# is strictly greater than the threshold.
# Keep these in sync with that summary file if the v1 probe is ever re-tuned.
AMR_MAX_THR = 6.583248138427734
AMR_MEAN_THR = -0.6747861504554749
def f1_at_threshold(scores, labels, thr):
    """Return ``(precision, recall, f1)`` for the binary predictions ``scores > thr``.

    Args:
        scores: Array-like of per-example scores (e.g. probe logits).
        labels: Array-like of ground-truth labels with the same shape as
            ``scores``. 1 marks a positive example and 0 a negative one.
        thr: Fixed decision threshold. An example is predicted positive only
            when its score is strictly greater than ``thr``.

    Returns:
        A ``(precision, recall, f1)`` tuple of Python floats. All three are 0.0
        when nothing is predicted positive or there are no positive labels. The
        undefined ratio is treated as 0, which also makes F1 0.

    Raises:
        ValueError: if ``scores`` and ``labels`` do not have the same shape.
    """
    # Accept plain lists as well as NumPy arrays. The elementwise comparisons
    # below only work on arrays (``list > float`` raises TypeError).
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    # Refuse to broadcast: a length mismatch would otherwise silently produce
    # meaningless confusion counts.
    if scores.shape != labels.shape:
        raise ValueError(
            "scores and labels must have the same shape, "
            f"got {scores.shape} and {labels.shape}"
        )

    predicted_pos = scores > thr
    is_pos = labels == 1
    is_neg = labels == 0
    tp = int((predicted_pos & is_pos).sum())
    fp = int((predicted_pos & is_neg).sum())
    fn = int((~predicted_pos & is_pos).sum())

    # No predicted positives, or no actual positives: tp is necessarily 0,
    # so precision, recall and F1 are all reported as 0.
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # The guard only matters when tp == 0 (precision == recall == 0).
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)
def _load_scores(scores_jsonl: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load one cached scores file as ``(labels, max_logits, mean_logits)`` arrays.

    Each non-blank line of ``scores_jsonl`` is parsed as a JSON object that must
    carry ``label``, ``max_logit`` and ``mean_logit`` keys. Labels are expected
    to be 1 (positive) / 0 (negative).

    Raises:
        ValueError: if the file yields no label-1 rows or no label-0 rows. In
            that case neither ROC AUC nor the two-class histograms are defined.
    """
    # Read as UTF-8 (the JSON encoding) and iterate real file lines instead of
    # str.splitlines(), which also splits on characters such as U+2028.
    with scores_jsonl.open(encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]

    labels = np.array([row["label"] for row in rows])
    # Force float dtype so later derived arrays (e.g. rug positions) never
    # inherit an integer dtype from integral-looking logits.
    max_logits = np.array([row["max_logit"] for row in rows], dtype=float)
    mean_logits = np.array([row["mean_logit"] for row in rows], dtype=float)

    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError(
            f"{scores_jsonl} must contain both label-1 and label-0 rows "
            f"(got n_pos={n_pos}, n_neg={n_neg})"
        )
    return labels, max_logits, mean_logits


def replot(scores_jsonl: Path, png_path: Path, summary_path: Path,
           pos_label: str, neg_label: str, suptitle: str,
           x_axis_label: str = "per-region probe logit (= w · h_pooled + b)"):
    """Re-render one score-distribution figure and summary at the fixed AMR thresholds.

    Reads cached per-example scores and computes ROC AUC for the max-pooled and
    mean-pooled logits. It also computes precision/recall/F1 at AMR_MAX_THR /
    AMR_MEAN_THR (no re-tuning), prints these metrics, and draws a two-panel
    figure. Each panel shows per-class density histograms, rug ticks, the
    logit=0 boundary and the fixed threshold. The figure is saved to
    ``png_path`` and the metrics are written (overwriting) to ``summary_path``.

    Args:
        scores_jsonl: JSONL file with ``label`` (1 = positive, 0 = negative),
            ``max_logit`` and ``mean_logit`` per line.
        png_path: Output PNG path. Parent directories are created.
        summary_path: Output JSON summary path. Parent directories are created.
        pos_label: Legend text for the label-1 histogram.
        neg_label: Legend text for the label-0 histogram.
        suptitle: Figure-level title.
        x_axis_label: X-axis label shared by both panels.

    Raises:
        ValueError: if ``scores_jsonl`` lacks rows of either class.
    """
    labels, max_logits, mean_logits = _load_scores(scores_jsonl)

    auc_max = float(roc_auc_score(labels, max_logits))
    auc_mean = float(roc_auc_score(labels, mean_logits))
    p_max, r_max, f1_max = f1_at_threshold(max_logits, labels, AMR_MAX_THR)
    p_mean, r_mean, f1_mean = f1_at_threshold(mean_logits, labels, AMR_MEAN_THR)
    print(f" max-pool: AUC={auc_max:.4f} fixed thr={AMR_MAX_THR:.3f} "
          f"P={p_max:.3f} R={r_max:.3f} F1={f1_max:.3f}")
    print(f" mean-pool: AUC={auc_mean:.4f} fixed thr={AMR_MEAN_THR:.3f} "
          f"P={p_mean:.3f} R={r_mean:.3f} F1={f1_mean:.3f}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        panels = [
            (axes[0], max_logits, "max-pool", auc_max, AMR_MAX_THR, p_max, r_max, f1_max),
            (axes[1], mean_logits, "mean-pool", auc_mean, AMR_MEAN_THR, p_mean, r_mean, f1_mean),
        ]
        for ax, scores, name, auc, fixed_thr, p, r, f1 in panels:
            pos = scores[labels == 1]
            neg = scores[labels == 0]
            # Shared bins (padded by 0.5 on each side) so both classes are
            # directly comparable.
            bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 50)
            ax.hist(neg, bins=bins, alpha=0.55, label=f"{neg_label} (n={len(neg)})", color="tab:red", density=True)
            ax.hist(pos, bins=bins, alpha=0.55, label=f"{pos_label} (n={len(pos)})", color="tab:green", density=True)

            # Rug ticks sit just below the histogram area: negatives 2% and
            # positives 3% of the y-range below the lower limit. The offsets are
            # additive, so both rows stay below the bars even when the lower
            # limit is not 0.
            y_lo, y_hi = ax.get_ylim()
            y_span = y_hi - y_lo
            neg_rug_y = y_lo - 0.02 * y_span
            pos_rug_y = y_lo - 0.03 * y_span
            ax.scatter(neg, np.full(len(neg), neg_rug_y), marker="|", s=60, color="tab:red", alpha=0.5)
            ax.scatter(pos, np.full(len(pos), pos_rug_y), marker="|", s=60, color="tab:green", alpha=0.5)

            ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
            ax.axvline(fixed_thr, color="black", lw=1, ls=":",
                       label=f"AMR-vs-MISC test boundary ({fixed_thr:.2f})")
            ax.set_xlabel(x_axis_label, fontsize=10)
            ax.set_ylabel("density", fontsize=10)
            ax.set_title(f"{name} • AUC={auc:.4f} • F1@fixed_thr={f1:.3f} "
                         f"(P={p:.2f}, R={r:.2f})", fontsize=10)
            ax.legend(fontsize=8, loc="upper left")
            ax.grid(True, alpha=0.2)

        fig.suptitle(suptitle, fontsize=10)
        fig.tight_layout()
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(png_path, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f" saved {png_path.stat().st_size/1024:.1f} KB to {png_path}")

    summary = {
        "fixed_thresholds_source": "amr_binary_v1_score_distributions.summary.json (TEST split of AMR-vs-MISC, the v1 probe's deployed thresholds)",
        "n_pos": int((labels == 1).sum()),
        "n_neg": int((labels == 0).sum()),
        "max_pool": {"auc": auc_max, "fixed_threshold": AMR_MAX_THR,
                     "precision": p_max, "recall": r_max, "f1": f1_max},
        "mean_pool": {"auc": auc_mean, "fixed_threshold": AMR_MEAN_THR,
                      "precision": p_mean, "recall": r_mean, "f1": f1_mean},
    }
    # The summary may live in a different directory from the PNG, so create
    # its parent too.
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f" saved summary to {summary_path}")
def main(results_dir: Path | str | None = None):
    """Re-render the reads and transfer plots (PNG + summary.json) at the fixed AMR thresholds.

    Args:
        results_dir: Directory holding the cached ``*.scores.jsonl`` files. The
            re-rendered ``*.png`` and ``*.summary.json`` files are written back
            into it. Defaults to the original hard-coded location, so running
            the script with no arguments behaves exactly as before.
    """
    if results_dir is None:
        results_dir = "/home/ror25cal/MGnify/probes/results"
    results_dir = Path(results_dir)

    def _paths(stem: str) -> dict[str, Path]:
        # Each plot keeps its cached scores, PNG and summary side by side,
        # sharing one file-name stem.
        return {
            "scores_jsonl": results_dir / f"{stem}.scores.jsonl",
            "png_path": results_dir / f"{stem}.png",
            "summary_path": results_dir / f"{stem}.summary.json",
        }

    # Reads plot
    print("=== Reads plot ===")
    replot(
        **_paths("probe_on_reads_score_distributions"),
        pos_label="AMR positive reads",
        neg_label="matched-neg reads",
        suptitle=(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
            "50 neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. F1 at AMR-vs-MISC test threshold (deployed)."
        ),
        x_axis_label="per-read probe logit (= w · h_pooled + b)",
    )

    # Transfer plot
    print("\n=== Transfer plot (AMR vs VFDB virulence) ===")
    replot(
        **_paths("transfer_amr_vs_virulence_score_distributions"),
        pos_label="AMR positive (CDS-only)",
        neg_label="VFDB virulence (CDS-only)",
        suptitle=(
            "TRANSFER TEST: v1 AMR linear probe — AMR positives (336) vs VFDB virulence positives (336)\n"
            "Pooled over CDS-only tokens for both classes (apples-to-apples).\n"
            "F1 at AMR-vs-MISC test threshold (deployed) — not re-tuned on this transfer set."
        ),
    )
# Script entry point: re-render both the reads and the transfer plots.
if __name__ == "__main__":
    main()