| """Re-render the multi-org-reads score-distribution plot showing both: |
| (a) the v1 AMR-vs-MISC deployed threshold (calibrated on 5kb extracts) |
| (b) the short-read-optimal threshold (calibrated on these reads themselves) |
| |
| Pitch story: a single linear probe works on short reads — the only thing |
| that needs to change between full-length and short-read deployment is the |
| *threshold*, not the model itself. |
| |
| Local-only; reads cached scores.jsonl, no Modal needed. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_auc_score, precision_recall_curve |
|
|
|
|
# Input scores and output plot both live in the probes results directory and
# share one file stem; the summary JSON is derived from OUT_PNG inside main().
_RESULTS_DIR = Path("/home/ror25cal/MGnify/probes/results")
_STEM = "multi_org_reads_score_distributions"
SCORES = _RESULTS_DIR / f"{_STEM}.scores.jsonl"
OUT_PNG = _RESULTS_DIR / f"{_STEM}.png"

# Fixed decision thresholds of the deployed v1 probe (recorded in the summary
# as coming from the "amr_binary_v1 test" set), one per pooling variant:
# applied to the max-pooled and mean-pooled per-read logits respectively.
DEPLOYED_THR_MAX = 6.583248138427734
DEPLOYED_THR_MEAN = -0.6747861504554749
|
|
|
|
def f1_at(scores, labels, thr):
    """Return ``(precision, recall, f1)`` when predicting positive for ``score >= thr``.

    The comparison is inclusive to match scikit-learn's
    ``precision_recall_curve``, whose thresholds are the observed score values
    themselves. With a strict ``>``, the read(s) sitting exactly on a
    threshold returned by ``best_f1_threshold`` would be flipped to negative,
    so the re-evaluated "optimal" F1 would come out below the optimum.

    Args:
        scores: per-sample scores (array-like; converted with ``np.asarray``).
        labels: per-sample labels; 1 = positive, 0 = negative. Any other value
            is counted in neither class.
        thr: decision threshold.

    Returns:
        ``(precision, recall, f1)`` as floats. All three are 0.0 when nothing is
        predicted positive or there are no positive labels.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    pred = scores >= thr
    is_pos = labels == 1
    is_neg = labels == 0
    tp = int(np.count_nonzero(pred & is_pos))
    fp = int(np.count_nonzero(pred & is_neg))
    fn = int(np.count_nonzero(~pred & is_pos))
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0
    p = tp / (tp + fp)
    r = tp / (tp + fn)
    # p + r can be 0 when tp == 0; the floor keeps the division defined (F1 = 0).
    return float(p), float(r), float(2 * p * r / max(p + r, 1e-9))
|
|
|
|
def best_f1_threshold(scores, labels):
    """Return ``(threshold, f1)`` at the maximum-F1 point of the PR curve.

    F1 is computed at every point of
    ``sklearn.metrics.precision_recall_curve(labels, scores)``, and the first
    point with the largest F1 is selected.
    """
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    denom = np.maximum(precision + recall, 1e-9)
    f1_curve = (2 * precision * recall) / denom
    best = int(np.argmax(f1_curve))
    # precision/recall carry one more point than thresholds (the final point
    # has no threshold), so clamp the index into the thresholds array.
    thr_idx = best if best < len(thresholds) else len(thresholds) - 1
    return float(thresholds[thr_idx]), float(f1_curve[best])
|
|
|
|
def _load_rows(path):
    """Parse *path* as JSON Lines, skipping blank lines; return the list of records."""
    return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]


def _threshold_stats(scores, labels, thr):
    """Return ``{"thr", "f1", "precision", "recall"}`` for *scores* cut at *thr* (via ``f1_at``)."""
    p, r, f1 = f1_at(scores, labels, thr)
    return {"thr": thr, "f1": f1, "precision": p, "recall": r}


def _evaluate_pool(scores, labels, deployed_thr):
    """Compute AUC plus stats at the deployed and at the short-read-optimal threshold.

    The short-read threshold is chosen by ``best_f1_threshold`` on these same
    reads, so its F1 is an in-sample (optimistic) estimate.
    """
    sr_thr, _ = best_f1_threshold(scores, labels)
    return {
        "auc": float(roc_auc_score(labels, scores)),
        "deployed": _threshold_stats(scores, labels, deployed_thr),
        "short_read": _threshold_stats(scores, labels, sr_thr),
    }


def _print_pool(name, auc_label, thr_fmt, ev):
    """Print the console report for one pooling variant."""
    dep, sr = ev["deployed"], ev["short_read"]
    print(f"{name}:")
    print(f" {auc_label}: {ev['auc']:.4f}")
    print(f" Deployed thr ({dep['thr']:{thr_fmt}}, full-length): "
          f"F1={dep['f1']:.3f} P={dep['precision']:.2f} R={dep['recall']:.2f}")
    print(f" Short-read optimal thr ({sr['thr']:{thr_fmt}}): "
          f"F1={sr['f1']:.3f} P={sr['precision']:.2f} R={sr['recall']:.2f}")


def _plot_panel(ax, scores, labels, name, ev):
    """Draw one panel: class-wise score densities, a rug strip, and threshold lines."""
    dep, sr = ev["deployed"], ev["short_read"]
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60)
    ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg reads (n={len(neg)})", color="tab:red", density=True)
    ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive reads (n={len(pos)})", color="tab:green", density=True)
    # Rug strips sit 2% / 3% of the y-span below the bottom of the current
    # y-range. Offsetting from ylim[0] (rather than scaling it) keeps them
    # below the bars wherever the axis happens to start.
    y_lo, y_hi = ax.get_ylim()
    span = y_hi - y_lo
    ax.scatter(neg, np.full_like(neg, y_lo - 0.02 * span), marker="|", s=20, color="tab:red", alpha=0.4)
    ax.scatter(pos, np.full_like(pos, y_lo - 0.03 * span), marker="|", s=20, color="tab:green", alpha=0.4)
    ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
    # Threshold tuned on full-length data (v1 deployment).
    ax.axvline(dep["thr"], color="black", lw=1.2, ls=":", alpha=0.55,
               label=f"v1 deployed thr ({dep['thr']:.2f}) → F1={dep['f1']:.3f} (R={dep['recall']:.2f})")
    # Threshold tuned on these short reads.
    ax.axvline(sr["thr"], color="tab:blue", lw=1.5, ls="-",
               label=f"short-read optimal thr ({sr['thr']:.2f}) → F1={sr['f1']:.3f} (R={sr['recall']:.2f})")
    ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10)
    ax.set_ylabel("density", fontsize=10)
    ax.set_title(f"{name} • AUC={ev['auc']:.4f} • short-read F1={sr['f1']:.3f} "
                 f"(deployed F1={dep['f1']:.3f})", fontsize=10)
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(True, alpha=0.2)


def _summary_entry(ev):
    """Build the JSON summary block for one pooling variant."""
    return {
        "auc": ev["auc"],
        "deployed_threshold": {**ev["deployed"], "source": "amr_binary_v1 test"},
        "short_read_optimal_threshold": {
            **ev["short_read"],
            "source": "this-data tuning (caveat: same set used for selection and evaluation)",
        },
    }


def main():
    """Evaluate both pooling variants on SCORES, print a report, and write OUT_PNG plus a summary JSON."""
    # One entry per pooling variant. The console formats are per-pool: the
    # mean-pool threshold is printed with an explicit sign.
    pools = [
        {"field": "max_logit", "name": "max-pool", "key": "max_pool",
         "deployed_thr": DEPLOYED_THR_MAX, "auc_label": "AUC (threshold-free)", "thr_fmt": ".2f"},
        {"field": "mean_logit", "name": "mean-pool", "key": "mean_pool",
         "deployed_thr": DEPLOYED_THR_MEAN, "auc_label": "AUC", "thr_fmt": "+.2f"},
    ]

    rows = _load_rows(SCORES)
    labels = np.array([r["label"] for r in rows])
    for pool in pools:
        pool["scores"] = np.array([r[pool["field"]] for r in rows])
        pool["eval"] = _evaluate_pool(pool["scores"], labels, pool["deployed_thr"])

    for pool in pools:
        _print_pool(pool["name"], pool["auc_label"], pool["thr_fmt"], pool["eval"])

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, pool in zip(axes, pools):
        _plot_panel(ax, pool["scores"], labels, pool["name"], pool["eval"])

    fig.suptitle(
        "v1 AMR linear probe applied to multi-organism short reads (301 bp simulated MiSeq)\n"
        "Same probe, two thresholds: v1's full-length-tuned threshold (dotted, gray) vs short-read-optimal (solid, blue).\n"
        "AUC is threshold-free and unaffected; only F1 changes. Recalibrating threshold for short reads recovers most of the deployment F1.",
        fontsize=10,
    )
    fig.tight_layout()
    OUT_PNG.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"\nsaved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}")

    summary = {
        "n_reads": len(rows),
        "n_pos_reads": int((labels == 1).sum()),
        "n_neg_reads": int((labels == 0).sum()),
        **{pool["key"]: _summary_entry(pool["eval"]) for pool in pools},
    }
    out_summary = OUT_PNG.with_suffix(".summary.json")
    out_summary.write_text(json.dumps(summary, indent=2))
    print(f"saved summary to {out_summary}")
|
|
|
|
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|