"""Re-render the multi-org-reads score-distribution plot showing both: (a) the v1 AMR-vs-MISC deployed threshold (calibrated on 5kb extracts) (b) the short-read-optimal threshold (calibrated on these reads themselves) Pitch story: a single linear probe works on short reads — the only thing that needs to change between full-length and short-read deployment is the *threshold*, not the model itself. Local-only; reads cached scores.jsonl, no Modal needed. """ from __future__ import annotations import json from pathlib import Path import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from sklearn.metrics import roc_auc_score, precision_recall_curve SCORES = Path("/home/ror25cal/MGnify/probes/results/multi_org_reads_score_distributions.scores.jsonl") OUT_PNG = Path("/home/ror25cal/MGnify/probes/results/multi_org_reads_score_distributions.png") # Deployed v1 thresholds from amr_binary_v1_score_distributions.summary.json DEPLOYED_THR_MAX = 6.583248138427734 DEPLOYED_THR_MEAN = -0.6747861504554749 def f1_at(scores, labels, thr): pred = (scores > thr).astype(int) tp = int(((pred == 1) & (labels == 1)).sum()) fp = int(((pred == 1) & (labels == 0)).sum()) fn = int(((pred == 0) & (labels == 1)).sum()) if tp + fp == 0 or tp + fn == 0: return 0.0, 0.0, 0.0 p = tp / (tp + fp) r = tp / (tp + fn) return float(p), float(r), float(2 * p * r / max(p + r, 1e-9)) def best_f1_threshold(scores, labels): prec, rec, thr = precision_recall_curve(labels, scores) f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9) idx = int(np.argmax(f1)) return float(thr[min(idx, len(thr) - 1)]), float(f1[idx]) def main(): rows = [json.loads(l) for l in SCORES.read_text().splitlines() if l.strip()] labels = np.array([r["label"] for r in rows]) max_logits = np.array([r["max_logit"] for r in rows]) mean_logits = np.array([r["mean_logit"] for r in rows]) auc_max = float(roc_auc_score(labels, max_logits)) auc_mean = float(roc_auc_score(labels, mean_logits)) # Deployed-threshold F1 (from full-length 5kb calibration) p_dep_max, r_dep_max, f1_dep_max = f1_at(max_logits, labels, DEPLOYED_THR_MAX) p_dep_mean, r_dep_mean, f1_dep_mean = f1_at(mean_logits, labels, DEPLOYED_THR_MEAN) # Short-read-optimal threshold (from these reads' scores) sr_thr_max, _ = best_f1_threshold(max_logits, labels) p_sr_max, r_sr_max, f1_sr_max = f1_at(max_logits, labels, sr_thr_max) sr_thr_mean, _ = best_f1_threshold(mean_logits, labels) p_sr_mean, r_sr_mean, f1_sr_mean = f1_at(mean_logits, labels, sr_thr_mean) print(f"max-pool:") print(f" AUC (threshold-free): {auc_max:.4f}") print(f" Deployed thr ({DEPLOYED_THR_MAX:.2f}, full-length): F1={f1_dep_max:.3f} P={p_dep_max:.2f} R={r_dep_max:.2f}") print(f" Short-read optimal thr ({sr_thr_max:.2f}): F1={f1_sr_max:.3f} P={p_sr_max:.2f} R={r_sr_max:.2f}") print(f"mean-pool:") print(f" AUC: {auc_mean:.4f}") print(f" Deployed thr ({DEPLOYED_THR_MEAN:+.2f}, full-length): F1={f1_dep_mean:.3f} P={p_dep_mean:.2f} R={r_dep_mean:.2f}") print(f" Short-read optimal thr ({sr_thr_mean:+.2f}): F1={f1_sr_mean:.3f} P={p_sr_mean:.2f} R={r_sr_mean:.2f}") fig, axes = plt.subplots(1, 2, figsize=(14, 5)) panel_data = [ (axes[0], max_logits, "max-pool", auc_max, DEPLOYED_THR_MAX, f1_dep_max, p_dep_max, r_dep_max, sr_thr_max, f1_sr_max, p_sr_max, r_sr_max), (axes[1], mean_logits, "mean-pool", auc_mean, DEPLOYED_THR_MEAN, f1_dep_mean, p_dep_mean, r_dep_mean, sr_thr_mean, f1_sr_mean, p_sr_mean, r_sr_mean), ] for (ax, scores, name, auc, dep_thr, f1_dep, p_dep, r_dep, sr_thr, f1_sr, p_sr, r_sr) in panel_data: pos = scores[labels == 1] neg = scores[labels == 0] bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60) ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg reads (n={len(neg)})", color="tab:red", density=True) ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive reads (n={len(pos)})", color="tab:green", density=True) rug_y = ax.get_ylim()[0] - 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0]) ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=20, color="tab:red", alpha=0.4) ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=20, color="tab:green", alpha=0.4) ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)") # Deployed threshold (from full-length 5kb calibration) — drawn lighter ax.axvline(dep_thr, color="black", lw=1.2, ls=":", alpha=0.55, label=f"v1 deployed thr ({dep_thr:.2f}) → F1={f1_dep:.3f} (R={r_dep:.2f})") # Short-read optimal threshold — drawn solid, the headline ax.axvline(sr_thr, color="tab:blue", lw=1.5, ls="-", label=f"short-read optimal thr ({sr_thr:.2f}) → F1={f1_sr:.3f} (R={r_sr:.2f})") ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10) ax.set_ylabel("density", fontsize=10) ax.set_title(f"{name} • AUC={auc:.4f} • short-read F1={f1_sr:.3f} " f"(deployed F1={f1_dep:.3f})", fontsize=10) ax.legend(fontsize=8, loc="upper left") ax.grid(True, alpha=0.2) fig.suptitle( "v1 AMR linear probe applied to multi-organism short reads (301 bp simulated MiSeq)\n" "Same probe, two thresholds: v1's full-length-tuned threshold (dotted, gray) vs short-read-optimal (solid, blue).\n" "AUC is threshold-free and unaffected; only F1 changes. Recalibrating threshold for short reads recovers most of the deployment F1.", fontsize=10, ) fig.tight_layout() OUT_PNG.parent.mkdir(parents=True, exist_ok=True) fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight") plt.close(fig) print(f"\nsaved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}") # Update summary.json summary = { "n_reads": len(rows), "n_pos_reads": int((labels == 1).sum()), "n_neg_reads": int((labels == 0).sum()), "max_pool": { "auc": auc_max, "deployed_threshold": {"thr": DEPLOYED_THR_MAX, "f1": f1_dep_max, "precision": p_dep_max, "recall": r_dep_max, "source": "amr_binary_v1 test"}, "short_read_optimal_threshold": {"thr": sr_thr_max, "f1": f1_sr_max, "precision": p_sr_max, "recall": r_sr_max, "source": "this-data tuning (caveat: same set used for selection and evaluation)"}, }, "mean_pool": { "auc": auc_mean, "deployed_threshold": {"thr": DEPLOYED_THR_MEAN, "f1": f1_dep_mean, "precision": p_dep_mean, "recall": r_dep_mean, "source": "amr_binary_v1 test"}, "short_read_optimal_threshold": {"thr": sr_thr_mean, "f1": f1_sr_mean, "precision": p_sr_mean, "recall": r_sr_mean, "source": "this-data tuning (caveat: same set used for selection and evaluation)"}, }, } out_summary = OUT_PNG.with_suffix(".summary.json") out_summary.write_text(json.dumps(summary, indent=2)) print(f"saved summary to {out_summary}") if __name__ == "__main__": main()