"""Re-render the reads and transfer score-distribution plots using the v1 AMR
probe's TEST-tuned thresholds (from amr_binary_v1) — i.e., the deployed
thresholds — instead of newly-tuned thresholds on each plot's data.

Local-only; reads cached scores.jsonl, recomputes F1 at fixed thresholds,
re-renders PNG, updates summary.json.

Usage:
    python probes/refit_thresholds.py
"""
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

# Thresholds from amr_binary_v1_score_distributions.summary.json (test set)
AMR_MAX_THR = 6.583248138427734
AMR_MEAN_THR = -0.6747861504554749

# Directory main() uses when the caller does not supply one.
DEFAULT_RESULTS_DIR = Path("/home/ror25cal/MGnify/probes/results")


def f1_at_threshold(scores, labels, thr):
    """Return ``(precision, recall, f1)`` for the rule ``score > thr``.

    Args:
        scores: Array-like of scores (lists and tuples are accepted as well
            as ndarrays).
        labels: Array-like of labels aligned with ``scores``; entries equal
            to 1 count as positives and entries equal to 0 as negatives.
        thr: Decision threshold; a sample is predicted positive when its
            score is strictly greater than ``thr``.

    Returns:
        A tuple of three Python floats. ``(0.0, 0.0, 0.0)`` is returned when
        nothing is predicted positive or when there are no positive labels,
        since precision or recall would otherwise divide by zero.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    predicted_pos = scores > thr
    is_pos = labels == 1
    is_neg = labels == 0

    tp = int((predicted_pos & is_pos).sum())
    fp = int((predicted_pos & is_neg).sum())
    fn = int((~predicted_pos & is_pos).sum())
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # The max() guards the tp == 0 case, where precision + recall is 0.
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)


def _load_scores(scores_jsonl: Path):
    """Read a scores JSONL file into ``(labels, max_logits, mean_logits)`` arrays."""
    rows = [
        json.loads(line)
        for line in scores_jsonl.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not rows:
        # Fail here with a clear message instead of deep inside sklearn/numpy.
        raise ValueError(f"no score rows found in {scores_jsonl}")
    labels = np.array([r["label"] for r in rows])
    # Force float so integer-valued logits cannot switch later arithmetic
    # (bin edges, rug offsets) to an integer dtype.
    max_logits = np.array([r["max_logit"] for r in rows], dtype=float)
    mean_logits = np.array([r["mean_logit"] for r in rows], dtype=float)
    return labels, max_logits, mean_logits


def _draw_panel(ax, scores, labels, name, auc, fixed_thr, p, r, f1,
                pos_label, neg_label, x_axis_label):
    """Draw one pooled-score panel: class histograms, rug marks, threshold lines."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 50)
    ax.hist(neg, bins=bins, alpha=0.55, label=f"{neg_label} (n={len(neg)})",
            color="tab:red", density=True)
    ax.hist(pos, bins=bins, alpha=0.55, label=f"{pos_label} (n={len(pos)})",
            color="tab:green", density=True)

    # Rug marks sit just below the histogram baseline; the positive row is
    # offset by a further factor of 1.5 so the two classes do not overlap.
    y_lo, y_hi = ax.get_ylim()
    rug_y = y_lo - 0.02 * (y_hi - y_lo)
    ax.scatter(neg, np.full(neg.shape, rug_y), marker="|", s=60,
               color="tab:red", alpha=0.5)
    ax.scatter(pos, np.full(pos.shape, rug_y * 1.5), marker="|", s=60,
               color="tab:green", alpha=0.5)

    ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
    ax.axvline(fixed_thr, color="black", lw=1, ls=":",
               label=f"AMR-vs-MISC test boundary ({fixed_thr:.2f})")
    ax.set_xlabel(x_axis_label, fontsize=10)
    ax.set_ylabel("density", fontsize=10)
    ax.set_title(f"{name} • AUC={auc:.4f} • F1@fixed_thr={f1:.3f} "
                 f"(P={p:.2f}, R={r:.2f})", fontsize=10)
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(True, alpha=0.2)


def replot(scores_jsonl: Path, png_path: Path, summary_path: Path,
           pos_label: str, neg_label: str, suptitle: str,
           x_axis_label: str = "per-region probe logit (= w · h_pooled + b)"):
    """Re-render one score-distribution plot at the fixed AMR thresholds.

    Reads per-sample records from ``scores_jsonl`` (one JSON object per
    non-blank line with ``label``, ``max_logit`` and ``mean_logit`` keys),
    computes ROC AUC plus precision/recall/F1 at ``AMR_MAX_THR`` and
    ``AMR_MEAN_THR``, prints both metric lines, saves a two-panel PNG to
    ``png_path`` and writes a JSON summary to ``summary_path``.

    Args:
        scores_jsonl: Cached scores file to read.
        png_path: Output image path; its parent directory is created if needed.
        summary_path: Output JSON summary path.
        pos_label: Legend text for samples labelled 1.
        neg_label: Legend text for samples labelled 0.
        suptitle: Figure-level title.
        x_axis_label: X-axis label applied to both panels.

    Raises:
        ValueError: If ``scores_jsonl`` contains no score rows.
    """
    labels, max_logits, mean_logits = _load_scores(scores_jsonl)

    auc_max = float(roc_auc_score(labels, max_logits))
    auc_mean = float(roc_auc_score(labels, mean_logits))
    p_max, r_max, f1_max = f1_at_threshold(max_logits, labels, AMR_MAX_THR)
    p_mean, r_mean, f1_mean = f1_at_threshold(mean_logits, labels, AMR_MEAN_THR)
    print(f" max-pool: AUC={auc_max:.4f} fixed thr={AMR_MAX_THR:.3f} "
          f"P={p_max:.3f} R={r_max:.3f} F1={f1_max:.3f}")
    print(f" mean-pool: AUC={auc_mean:.4f} fixed thr={AMR_MEAN_THR:.3f} "
          f"P={p_mean:.3f} R={r_mean:.3f} F1={f1_mean:.3f}")

    panels = [
        (max_logits, "max-pool", auc_max, AMR_MAX_THR, p_max, r_max, f1_max),
        (mean_logits, "mean-pool", auc_mean, AMR_MEAN_THR, p_mean, r_mean, f1_mean),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        for ax, (scores, name, auc, fixed_thr, p, r, f1) in zip(axes, panels):
            _draw_panel(ax, scores, labels, name, auc, fixed_thr, p, r, f1,
                        pos_label, neg_label, x_axis_label)
        fig.suptitle(suptitle, fontsize=10)
        fig.tight_layout()
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(png_path, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f" saved {png_path.stat().st_size/1024:.1f} KB to {png_path}")

    summary = {
        "fixed_thresholds_source": "amr_binary_v1_score_distributions.summary.json (TEST split of AMR-vs-MISC, the v1 probe's deployed thresholds)",
        "n_pos": int((labels == 1).sum()),
        "n_neg": int((labels == 0).sum()),
        "max_pool": {"auc": auc_max, "fixed_threshold": AMR_MAX_THR,
                     "precision": p_max, "recall": r_max, "f1": f1_max},
        "mean_pool": {"auc": auc_mean, "fixed_threshold": AMR_MEAN_THR,
                      "precision": p_mean, "recall": r_mean, "f1": f1_mean},
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f" saved summary to {summary_path}")


def main(results_dir: Path | None = None):
    """Re-render the reads plot and the AMR-vs-virulence transfer plot.

    Args:
        results_dir: Directory holding the cached ``*.scores.jsonl`` files,
            which also receives the PNG and summary outputs. Defaults to
            ``DEFAULT_RESULTS_DIR``.
    """
    results_dir = DEFAULT_RESULTS_DIR if results_dir is None else Path(results_dir)

    def artefacts(stem: str) -> dict:
        # The three files of one plot share a stem and differ only in suffix.
        return {
            "scores_jsonl": results_dir / f"{stem}.scores.jsonl",
            "png_path": results_dir / f"{stem}.png",
            "summary_path": results_dir / f"{stem}.summary.json",
        }

    # Reads plot
    print("=== Reads plot ===")
    replot(
        **artefacts("probe_on_reads_score_distributions"),
        pos_label="AMR positive reads",
        neg_label="matched-neg reads",
        suptitle=(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
            "50 neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. F1 at AMR-vs-MISC test threshold (deployed)."
        ),
        x_axis_label="per-read probe logit (= w · h_pooled + b)",
    )

    # Transfer plot
    print("\n=== Transfer plot (AMR vs VFDB virulence) ===")
    replot(
        **artefacts("transfer_amr_vs_virulence_score_distributions"),
        pos_label="AMR positive (CDS-only)",
        neg_label="VFDB virulence (CDS-only)",
        suptitle=(
            "TRANSFER TEST: v1 AMR linear probe — AMR positives (336) vs VFDB virulence positives (336)\n"
            "Pooled over CDS-only tokens for both classes (apples-to-apples).\n"
            "F1 at AMR-vs-MISC test threshold (deployed) — not re-tuned on this transfer set."
        ),
    )


if __name__ == "__main__":
    main()