"""Re-render the reads and transfer score-distribution plots using the
v1 AMR probe's TEST-tuned thresholds (from amr_binary_v1) — i.e., the
deployed thresholds — instead of thresholds newly tuned on each plot's data.

Local-only: reads the cached scores.jsonl files, recomputes AUC and
precision/recall/F1 at the fixed thresholds, re-renders the PNGs, and
rewrites the summary.json files.

Usage:
    python probes/refit_thresholds.py
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
# Thresholds from amr_binary_v1_score_distributions.summary.json (test set).
# These are the v1 probe's deployed thresholds (tuned on the TEST split of
# AMR-vs-MISC); they are applied unchanged to every plot re-rendered here.
AMR_MAX_THR: float = 6.583248138427734  # threshold on max-pooled logits
AMR_MEAN_THR: float = -0.6747861504554749  # threshold on mean-pooled logits
def f1_at_threshold(scores, labels, thr):
    """Return (precision, recall, F1) when predicting positive iff score > thr.

    Args:
        scores: array-like of real-valued scores (e.g. probe logits).
        labels: array-like of ground-truth labels aligned element-wise with
            ``scores``; 1 marks a positive and 0 a negative.
        thr: decision threshold. The comparison is strict, so a score exactly
            equal to ``thr`` is predicted negative.

    Returns:
        ``(precision, recall, f1)`` as Python floats. All three are 0.0 when
        nothing is predicted positive or there are no positive labels, because
        precision or recall would then be 0/0.
    """
    # Coerce lists/tuples so the elementwise comparisons below work;
    # NumPy arrays pass through without a copy.
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    pred = (scores > thr).astype(int)
    is_pos = labels == 1
    tp = int(((pred == 1) & is_pos).sum())
    fp = int(((pred == 1) & (labels == 0)).sum())
    fn = int(((pred == 0) & is_pos).sum())
    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # precision and recall are both 0 when tp == 0; guard the denominator.
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)
def _load_scores(scores_jsonl: Path):
    """Parse a scores.jsonl file into ``(labels, max_logits, mean_logits)`` arrays.

    Each non-blank line must be a JSON object with "label", "max_logit" and
    "mean_logit" keys.
    """
    # JSON text is UTF-8 by spec; don't let the platform locale pick the codec.
    lines = scores_jsonl.read_text(encoding="utf-8").splitlines()
    rows = [json.loads(line) for line in lines if line.strip()]
    labels = np.array([row["label"] for row in rows])
    max_logits = np.array([row["max_logit"] for row in rows])
    mean_logits = np.array([row["mean_logit"] for row in rows])
    return labels, max_logits, mean_logits


def _render_png(png_path: Path, labels, panels, pos_label: str, neg_label: str,
                suptitle: str, x_axis_label: str):
    """Draw one side-by-side histogram panel per entry of *panels* and save to *png_path*.

    Each panel is ``(scores, name, auc, fixed_thr, precision, recall, f1)``.
    The figure is always closed, even if drawing or saving raises.
    """
    n_panels = len(panels)
    # squeeze=False keeps `axes` 2-D regardless of the panel count.
    fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)
    try:
        for ax, (scores, name, auc, fixed_thr, p, r, f1) in zip(axes[0], panels):
            pos = scores[labels == 1]
            neg = scores[labels == 0]
            bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 50)
            ax.hist(neg, bins=bins, alpha=0.55, label=f"{neg_label} (n={len(neg)})", color="tab:red", density=True)
            ax.hist(pos, bins=bins, alpha=0.55, label=f"{pos_label} (n={len(pos)})", color="tab:green", density=True)
            # Rug marks just below the histogram floor so individual scores stay
            # visible: negatives at rug_y, positives at 1.5 * rug_y.
            y_lo, y_hi = ax.get_ylim()
            rug_y = y_lo - 0.02 * (y_hi - y_lo)
            ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=60, color="tab:red", alpha=0.5)
            ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=60, color="tab:green", alpha=0.5)
            ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
            ax.axvline(fixed_thr, color="black", lw=1, ls=":",
                       label=f"AMR-vs-MISC test boundary ({fixed_thr:.2f})")
            ax.set_xlabel(x_axis_label, fontsize=10)
            ax.set_ylabel("density", fontsize=10)
            ax.set_title(f"{name} • AUC={auc:.4f} • F1@fixed_thr={f1:.3f} "
                         f"(P={p:.2f}, R={r:.2f})", fontsize=10)
            ax.legend(fontsize=8, loc="upper left")
            ax.grid(True, alpha=0.2)
        fig.suptitle(suptitle, fontsize=10)
        fig.tight_layout()
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(png_path, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)


def replot(scores_jsonl: Path, png_path: Path, summary_path: Path,
           pos_label: str, neg_label: str, suptitle: str,
           x_axis_label: str = "per-region probe logit (= w · h_pooled + b)"):
    """Re-render one score-distribution plot at the fixed deployed thresholds.

    Reads cached per-example scores from *scores_jsonl*, computes AUC and
    precision/recall/F1 at ``AMR_MAX_THR`` (max-pool) and ``AMR_MEAN_THR``
    (mean-pool), prints them, writes a two-panel histogram PNG to *png_path*
    and overwrites *summary_path* with the metrics as JSON.

    Args:
        scores_jsonl: JSONL file with "label", "max_logit", "mean_logit" per line.
        png_path: output image path; parent directories are created if missing.
        summary_path: output JSON path (overwritten).
        pos_label: legend text for label == 1 examples.
        neg_label: legend text for label == 0 examples.
        suptitle: figure title.
        x_axis_label: x-axis label shared by both panels.

    Raises:
        ValueError: if the file does not contain both label-1 and label-0 rows.
    """
    labels, max_logits, mean_logits = _load_scores(scores_jsonl)
    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        # AUC and the per-class histograms need both classes; fail clearly here
        # instead of with an opaque sklearn/numpy error further down.
        raise ValueError(
            f"{scores_jsonl}: need both label 1 and label 0 rows, "
            f"got n_pos={n_pos}, n_neg={n_neg}"
        )

    auc_max = float(roc_auc_score(labels, max_logits))
    auc_mean = float(roc_auc_score(labels, mean_logits))
    p_max, r_max, f1_max = f1_at_threshold(max_logits, labels, AMR_MAX_THR)
    p_mean, r_mean, f1_mean = f1_at_threshold(mean_logits, labels, AMR_MEAN_THR)
    print(f" max-pool: AUC={auc_max:.4f} fixed thr={AMR_MAX_THR:.3f} "
          f"P={p_max:.3f} R={r_max:.3f} F1={f1_max:.3f}")
    print(f" mean-pool: AUC={auc_mean:.4f} fixed thr={AMR_MEAN_THR:.3f} "
          f"P={p_mean:.3f} R={r_mean:.3f} F1={f1_mean:.3f}")

    panels = [
        (max_logits, "max-pool", auc_max, AMR_MAX_THR, p_max, r_max, f1_max),
        (mean_logits, "mean-pool", auc_mean, AMR_MEAN_THR, p_mean, r_mean, f1_mean),
    ]
    _render_png(png_path, labels, panels, pos_label, neg_label, suptitle, x_axis_label)
    print(f" saved {png_path.stat().st_size/1024:.1f} KB to {png_path}")

    summary = {
        "fixed_thresholds_source": "amr_binary_v1_score_distributions.summary.json (TEST split of AMR-vs-MISC, the v1 probe's deployed thresholds)",
        "n_pos": n_pos,
        "n_neg": n_neg,
        "max_pool": {"auc": auc_max, "fixed_threshold": AMR_MAX_THR,
                     "precision": p_max, "recall": r_max, "f1": f1_max},
        "mean_pool": {"auc": auc_mean, "fixed_threshold": AMR_MEAN_THR,
                      "precision": p_mean, "recall": r_mean, "f1": f1_mean},
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f" saved summary to {summary_path}")
def main(results_dir: Path | str = Path("/home/ror25cal/MGnify/probes/results")):
    """Re-render the reads plot and the AMR-vs-virulence transfer plot.

    Args:
        results_dir: directory holding each plot's ``<stem>.scores.jsonl``
            input; the ``<stem>.png`` and ``<stem>.summary.json`` outputs are
            written next to it. The default is the original machine-specific
            location, so ``main()`` with no arguments behaves as before.
    """
    results_dir = Path(results_dir)

    def artifacts(stem: str) -> dict[str, Path]:
        # Each plot keeps its cached scores, PNG and summary side by side,
        # distinguished only by suffix.
        return {
            "scores_jsonl": results_dir / f"{stem}.scores.jsonl",
            "png_path": results_dir / f"{stem}.png",
            "summary_path": results_dir / f"{stem}.summary.json",
        }

    # Reads plot
    print("=== Reads plot ===")
    replot(
        **artifacts("probe_on_reads_score_distributions"),
        pos_label="AMR positive reads",
        neg_label="matched-neg reads",
        suptitle=(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
            "50 neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. F1 at AMR-vs-MISC test threshold (deployed)."
        ),
        x_axis_label="per-read probe logit (= w · h_pooled + b)",
    )

    # Transfer plot
    print("\n=== Transfer plot (AMR vs VFDB virulence) ===")
    replot(
        **artifacts("transfer_amr_vs_virulence_score_distributions"),
        pos_label="AMR positive (CDS-only)",
        neg_label="VFDB virulence (CDS-only)",
        suptitle=(
            "TRANSFER TEST: v1 AMR linear probe — AMR positives (336) vs VFDB virulence positives (336)\n"
            "Pooled over CDS-only tokens for both classes (apples-to-apples).\n"
            "F1 at AMR-vs-MISC test threshold (deployed) — not re-tuned on this transfer set."
        ),
    )


if __name__ == "__main__":
    main()