# mgnify-evo2-probes / code / probes / refit_short_read_threshold.py
# (Hugging Face page header residue: "JG1310's picture";
#  "Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs";
#  "eb69de4 verified")
"""Re-render the multi-org-reads score-distribution plot showing both:
(a) the v1 AMR-vs-MISC deployed threshold (calibrated on 5kb extracts)
(b) the short-read-optimal threshold (calibrated on these reads themselves)
Pitch story: a single linear probe works on short reads — the only thing
that needs to change between full-length and short-read deployment is the
*threshold*, not the model itself.
Local-only: reads the cached scores.jsonl (no Modal needed) and writes the PNG
plus a companion .summary.json next to it.
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, precision_recall_curve
# Every file this script reads or writes lives in one results directory.
_RESULTS_DIR = Path("/home/ror25cal/MGnify/probes/results")
# Cached per-read scores: one JSON object per line; main() reads the
# "label", "max_logit" and "mean_logit" fields of each object.
SCORES = _RESULTS_DIR / "multi_org_reads_score_distributions.scores.jsonl"
# Plot destination; main() also writes a sibling ".summary.json" next to it.
OUT_PNG = _RESULTS_DIR / "multi_org_reads_score_distributions.png"

# Deployed v1 thresholds from amr_binary_v1_score_distributions.summary.json
# (the full-length calibration the module docstring refers to).
DEPLOYED_THR_MAX = 6.583248138427734  # compared against max-pooled read logits
DEPLOYED_THR_MEAN = -0.6747861504554749  # compared against mean-pooled read logits
def f1_at(scores, labels, thr):
    """Return ``(precision, recall, f1)`` when predicting positive for ``score >= thr``.

    The comparison is inclusive so it matches
    ``sklearn.metrics.precision_recall_curve``, whose thresholds are actual
    score values evaluated as ``score >= threshold``. With a strict ``>``, the
    sample(s) sitting exactly on a threshold returned by ``best_f1_threshold``
    would be dropped, and the reported "optimal" F1 would be too low.

    Args:
        scores: per-sample scores (array-like of numbers).
        labels: per-sample labels (array-like); 1 counts as positive, 0 as
            negative, and any other value is ignored by the counts.
        thr: decision threshold.

    Returns:
        ``(precision, recall, f1)`` as Python floats. All three are 0.0 when
        nothing is predicted positive or there are no positive labels.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    pred = scores >= thr
    is_pos = labels == 1
    is_neg = labels == 0
    tp = int((pred & is_pos).sum())
    fp = int((pred & is_neg).sum())
    fn = int((~pred & is_pos).sum())
    # Precision or recall would be 0/0; report a zero operating point instead.
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)
def best_f1_threshold(scores, labels):
    """Return ``(threshold, f1)`` for the cut-off that maximises F1 on this data.

    Candidate cut-offs are the thresholds from sklearn's
    ``precision_recall_curve``. The first maximum wins on ties. The returned F1
    is the one from the curve itself.
    """
    precision, recall, cutoffs = precision_recall_curve(labels, scores)
    denom = np.maximum(precision + recall, 1e-9)
    f1_curve = (2 * precision * recall) / denom
    best = int(np.argmax(f1_curve))
    # precision/recall carry one more entry than cutoffs; clamp to the last cutoff.
    chosen = cutoffs[best] if best < len(cutoffs) else cutoffs[-1]
    return float(chosen), float(f1_curve[best])
def _load_rows(path):
    """Parse the cached per-read scores JSONL into dicts, skipping blank lines.

    Raises:
        ValueError: if the file holds no score rows. AUC and threshold fitting
            are meaningless on an empty set, so fail here with a clear message.
    """
    text = path.read_text(encoding="utf-8")
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    if not rows:
        raise ValueError(f"no score rows found in {path}")
    return rows


def _pool_metrics(scores, labels, deployed_thr):
    """Evaluate one pooling at the deployed and the short-read-optimal threshold.

    Returns a dict with ``auc`` and two operating points, ``deployed`` and
    ``short_read``. Each holds ``thr``, ``f1``, ``precision`` and ``recall``,
    in the key order used by the summary JSON.
    """
    auc = float(roc_auc_score(labels, scores))
    p_dep, r_dep, f1_dep = f1_at(scores, labels, deployed_thr)
    # Tuned and evaluated on the same reads, so this F1 is optimistic.
    sr_thr, _ = best_f1_threshold(scores, labels)
    p_sr, r_sr, f1_sr = f1_at(scores, labels, sr_thr)
    return {
        "auc": auc,
        "deployed": {"thr": deployed_thr, "f1": f1_dep, "precision": p_dep, "recall": r_dep},
        "short_read": {"thr": sr_thr, "f1": f1_sr, "precision": p_sr, "recall": r_sr},
    }


def _print_pool(name, auc_label, thr_fmt, metrics):
    """Print AUC and both operating points for one pooling to stdout."""
    dep, sr = metrics["deployed"], metrics["short_read"]
    print(f"{name}:")
    print(f" {auc_label}: {metrics['auc']:.4f}")
    print(f" Deployed thr ({dep['thr']:{thr_fmt}}, full-length): "
          f"F1={dep['f1']:.3f} P={dep['precision']:.2f} R={dep['recall']:.2f}")
    print(f" Short-read optimal thr ({sr['thr']:{thr_fmt}}): "
          f"F1={sr['f1']:.3f} P={sr['precision']:.2f} R={sr['recall']:.2f}")


def _plot_panel(ax, scores, labels, name, metrics):
    """Draw one pooling's class histograms, rug strips and threshold lines on ``ax``."""
    dep, sr = metrics["deployed"], metrics["short_read"]
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60)
    ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg reads (n={len(neg)})", color="tab:red", density=True)
    ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive reads (n={len(pos)})", color="tab:green", density=True)
    # Rug strips just under the histogram baseline. The offsets are measured down
    # from the current lower y-limit, so both strips stay below the bars even
    # when that limit is non-zero. Scaling a y value by 1.5 only moves it down
    # when the value is negative.
    y_lo, y_hi = ax.get_ylim()
    y_span = y_hi - y_lo
    ax.scatter(neg, np.full(neg.shape, y_lo - 0.02 * y_span), marker="|", s=20, color="tab:red", alpha=0.4)
    ax.scatter(pos, np.full(pos.shape, y_lo - 0.03 * y_span), marker="|", s=20, color="tab:green", alpha=0.4)
    ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
    # Deployed threshold (from full-length 5kb calibration) — drawn lighter
    ax.axvline(dep["thr"], color="black", lw=1.2, ls=":", alpha=0.55,
               label=f"v1 deployed thr ({dep['thr']:.2f}) → F1={dep['f1']:.3f} (R={dep['recall']:.2f})")
    # Short-read optimal threshold — drawn solid, the headline
    ax.axvline(sr["thr"], color="tab:blue", lw=1.5, ls="-",
               label=f"short-read optimal thr ({sr['thr']:.2f}) → F1={sr['f1']:.3f} (R={sr['recall']:.2f})")
    ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10)
    ax.set_ylabel("density", fontsize=10)
    ax.set_title(f"{name} • AUC={metrics['auc']:.4f} • short-read F1={sr['f1']:.3f} "
                 f"(deployed F1={dep['f1']:.3f})", fontsize=10)
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(True, alpha=0.2)


def _summary_entry(metrics):
    """Build one pooling's section of the summary JSON, with provenance notes."""
    return {
        "auc": metrics["auc"],
        "deployed_threshold": {**metrics["deployed"], "source": "amr_binary_v1 test"},
        "short_read_optimal_threshold": {
            **metrics["short_read"],
            "source": "this-data tuning (caveat: same set used for selection and evaluation)",
        },
    }


def main():
    """Re-score the cached reads under both thresholds, then plot and summarise.

    Reads ``SCORES``, prints max-/mean-pool metrics, and writes ``OUT_PNG`` and
    ``OUT_PNG.with_suffix(".summary.json")``. The summary file is overwritten.
    """
    rows = _load_rows(SCORES)
    labels = np.array([r["label"] for r in rows])
    max_logits = np.array([r["max_logit"] for r in rows])
    mean_logits = np.array([r["mean_logit"] for r in rows])

    max_metrics = _pool_metrics(max_logits, labels, DEPLOYED_THR_MAX)
    mean_metrics = _pool_metrics(mean_logits, labels, DEPLOYED_THR_MEAN)

    _print_pool("max-pool", "AUC (threshold-free)", ".2f", max_metrics)
    _print_pool("mean-pool", "AUC", "+.2f", mean_metrics)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    _plot_panel(axes[0], max_logits, labels, "max-pool", max_metrics)
    _plot_panel(axes[1], mean_logits, labels, "mean-pool", mean_metrics)
    fig.suptitle(
        "v1 AMR linear probe applied to multi-organism short reads (301 bp simulated MiSeq)\n"
        "Same probe, two thresholds: v1's full-length-tuned threshold (dotted, gray) vs short-read-optimal (solid, blue).\n"
        "AUC is threshold-free and unaffected; only F1 changes. Recalibrating threshold for short reads recovers most of the deployment F1.",
        fontsize=10,
    )
    fig.tight_layout()
    OUT_PNG.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"\nsaved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}")

    # Write (overwrite, not merge) the companion summary JSON next to the PNG.
    summary = {
        "n_reads": len(rows),
        "n_pos_reads": int((labels == 1).sum()),
        "n_neg_reads": int((labels == 0).sum()),
        "max_pool": _summary_entry(max_metrics),
        "mean_pool": _summary_entry(mean_metrics),
    }
    out_summary = OUT_PNG.with_suffix(".summary.json")
    out_summary.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"saved summary to {out_summary}")
# Script entry point: run the re-plot only when executed directly, not on import.
if __name__ == "__main__":
    main()