| """Re-render the reads and transfer score-distribution plots using the |
| v1 AMR probe's TEST-tuned thresholds (from amr_binary_v1) — i.e., the |
| deployed thresholds — instead of newly-tuned thresholds on each plot's data. |
| |
| Local-only; reads cached scores.jsonl, recomputes F1 at fixed thresholds, |
| re-renders PNG, updates summary.json. |
| |
| Usage: |
| python probes/refit_thresholds.py |
| """ |
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_auc_score |
|
|
|
|
| |
# Fixed decision thresholds for the v1 AMR probe, hard-coded here so that every
# plot is scored at the same cut-offs instead of being re-tuned on its own data.
# Per the module docstring and the "fixed_thresholds_source" field written by
# replot(), they come from amr_binary_v1_score_distributions.summary.json
# (TEST split of AMR-vs-MISC).
# AMR_MAX_THR is applied to each row's "max_logit"; AMR_MEAN_THR to "mean_logit".
| AMR_MAX_THR = 6.583248138427734 |
| AMR_MEAN_THR = -0.6747861504554749 |
|
|
|
|
def f1_at_threshold(scores, labels, thr: float) -> tuple[float, float, float]:
    """Return precision, recall and F1 for the rule ``score > thr``.

    Args:
        scores: Array-like of per-example scores. An example is predicted
            positive when its score is strictly greater than ``thr``.
        labels: Array-like of ground-truth labels aligned with ``scores``.
            ``1`` marks a positive and ``0`` a negative; any other value is
            counted as neither.
        thr: Decision threshold.

    Returns:
        ``(precision, recall, f1)`` as Python floats. All three are ``0.0``
        when there are no predicted positives or no actual positives.
    """
    # Coerce up front so plain lists/tuples behave like NumPy arrays in the
    # element-wise comparisons below.
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    predicted_pos = scores > thr
    actual_pos = labels == 1
    tp = int(np.count_nonzero(predicted_pos & actual_pos))
    fp = int(np.count_nonzero(predicted_pos & (labels == 0)))
    fn = int(np.count_nonzero(~predicted_pos & actual_pos))

    # Precision or recall would divide by zero; report an all-zero score.
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # The floor on the denominator covers tp == 0, where both terms are zero.
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)
|
|
|
|
def replot(scores_jsonl: Path, png_path: Path, summary_path: Path,
           pos_label: str, neg_label: str, suptitle: str,
           x_axis_label: str = "per-region probe logit (= w · h_pooled + b)") -> None:
    """Re-score cached probe logits at the fixed thresholds and rewrite outputs.

    Reads one JSON object per non-blank line of ``scores_jsonl`` (each must
    carry ``"label"``, ``"max_logit"`` and ``"mean_logit"``), computes ROC AUC
    plus precision/recall/F1 at ``AMR_MAX_THR`` / ``AMR_MEAN_THR``, draws a
    two-panel score-distribution figure and writes a JSON summary.

    Args:
        scores_jsonl: Path to the cached per-example scores (JSON lines).
        png_path: Destination of the rendered figure; its parent directory is
            created if missing.
        summary_path: Destination of the JSON summary (overwritten); its
            parent directory is created if missing.
        pos_label: Legend text for examples with label ``1``.
        neg_label: Legend text for examples with label ``0``.
        suptitle: Figure-level title.
        x_axis_label: X-axis label used on both panels.

    Raises:
        ValueError: If ``scores_jsonl`` contains no score rows.
    """
    rows = [
        json.loads(line)
        for line in scores_jsonl.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not rows:
        raise ValueError(f"no score rows found in {scores_jsonl}")

    labels = np.array([row["label"] for row in rows])
    # Force float so that np.full_like below cannot inherit an integer dtype
    # (which would truncate the small negative rug offsets to 0).
    max_logits = np.array([row["max_logit"] for row in rows], dtype=float)
    mean_logits = np.array([row["mean_logit"] for row in rows], dtype=float)

    auc_max = float(roc_auc_score(labels, max_logits))
    auc_mean = float(roc_auc_score(labels, mean_logits))

    p_max, r_max, f1_max = f1_at_threshold(max_logits, labels, AMR_MAX_THR)
    p_mean, r_mean, f1_mean = f1_at_threshold(mean_logits, labels, AMR_MEAN_THR)

    print(f" max-pool: AUC={auc_max:.4f} fixed thr={AMR_MAX_THR:.3f} "
          f"P={p_max:.3f} R={r_max:.3f} F1={f1_max:.3f}")
    print(f" mean-pool: AUC={auc_mean:.4f} fixed thr={AMR_MEAN_THR:.3f} "
          f"P={p_mean:.3f} R={r_mean:.3f} F1={f1_mean:.3f}")

    # One entry per subplot, left to right.
    panels = [
        (max_logits, "max-pool", auc_max, AMR_MAX_THR, p_max, r_max, f1_max),
        (mean_logits, "mean-pool", auc_mean, AMR_MEAN_THR, p_mean, r_mean, f1_mean),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        for ax, (scores, name, auc, fixed_thr, p, r, f1) in zip(axes, panels):
            pos = scores[labels == 1]
            neg = scores[labels == 0]
            # Shared bin edges so both classes' histograms are comparable.
            bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 50)
            ax.hist(neg, bins=bins, alpha=0.55, label=f"{neg_label} (n={len(neg)})",
                    color="tab:red", density=True)
            ax.hist(pos, bins=bins, alpha=0.55, label=f"{pos_label} (n={len(pos)})",
                    color="tab:green", density=True)

            # Rug marks sit just below the histogram baseline; positives are
            # pushed 1.5x further down so the two rugs do not overlap.
            y_lo, y_hi = ax.get_ylim()
            rug_y = y_lo - 0.02 * (y_hi - y_lo)
            ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=60,
                       color="tab:red", alpha=0.5)
            ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=60,
                       color="tab:green", alpha=0.5)

            ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
            ax.axvline(fixed_thr, color="black", lw=1, ls=":",
                       label=f"AMR-vs-MISC test boundary ({fixed_thr:.2f})")
            ax.set_xlabel(x_axis_label, fontsize=10)
            ax.set_ylabel("density", fontsize=10)
            ax.set_title(f"{name} • AUC={auc:.4f} • F1@fixed_thr={f1:.3f} "
                         f"(P={p:.2f}, R={r:.2f})", fontsize=10)
            ax.legend(fontsize=8, loc="upper left")
            ax.grid(True, alpha=0.2)

        fig.suptitle(suptitle, fontsize=10)
        fig.tight_layout()
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(png_path, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f" saved {png_path.stat().st_size/1024:.1f} KB to {png_path}")

    summary = {
        "fixed_thresholds_source": "amr_binary_v1_score_distributions.summary.json (TEST split of AMR-vs-MISC, the v1 probe's deployed thresholds)",
        "n_pos": int((labels == 1).sum()),
        "n_neg": int((labels == 0).sum()),
        "max_pool": {"auc": auc_max, "fixed_threshold": AMR_MAX_THR,
                     "precision": p_max, "recall": r_max, "f1": f1_max},
        "mean_pool": {"auc": auc_mean, "fixed_threshold": AMR_MEAN_THR,
                      "precision": p_mean, "recall": r_mean, "f1": f1_mean},
    }
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f" saved summary to {summary_path}")
|
|
|
|
def main(results_dir: Path | str | None = None) -> None:
    """Re-render the reads plot and the AMR-vs-virulence transfer plot.

    Args:
        results_dir: Directory holding the cached ``*.scores.jsonl`` files;
            the PNGs and summaries are written next to them. Defaults to the
            original hard-coded results location when omitted.
    """
    if results_dir is None:
        results_dir = "/home/ror25cal/MGnify/probes/results"
    results_dir = Path(results_dir)

    print("=== Reads plot ===")
    reads_stem = "probe_on_reads_score_distributions"
    replot(
        scores_jsonl=results_dir / f"{reads_stem}.scores.jsonl",
        png_path=results_dir / f"{reads_stem}.png",
        summary_path=results_dir / f"{reads_stem}.summary.json",
        pos_label="AMR positive reads",
        neg_label="matched-neg reads",
        suptitle=(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
            "50 neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. F1 at AMR-vs-MISC test threshold (deployed)."
        ),
        x_axis_label="per-read probe logit (= w · h_pooled + b)",
    )

    print("\n=== Transfer plot (AMR vs VFDB virulence) ===")
    transfer_stem = "transfer_amr_vs_virulence_score_distributions"
    replot(
        scores_jsonl=results_dir / f"{transfer_stem}.scores.jsonl",
        png_path=results_dir / f"{transfer_stem}.png",
        summary_path=results_dir / f"{transfer_stem}.summary.json",
        pos_label="AMR positive (CDS-only)",
        neg_label="VFDB virulence (CDS-only)",
        suptitle=(
            "TRANSFER TEST: v1 AMR linear probe — AMR positives (336) vs VFDB virulence positives (336)\n"
            "Pooled over CDS-only tokens for both classes (apples-to-apples).\n"
            "F1 at AMR-vs-MISC test threshold (deployed) — not re-tuned on this transfer set."
        ),
    )


if __name__ == "__main__":
    main()
|
|