File size: 4,518 Bytes
eb69de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Generate the score-distribution plot for the v1 AMR probe applied to 100
simulated short reads. Scores are loaded from the cached probe_on_reads_test.json,
so no Modal run is needed.

Output mirrors the AMR distribution plot: a PNG plus sibling .scores.jsonl and
.summary.json files.

Usage:
  python probes/make_reads_score_plot.py
"""
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, precision_recall_curve


# Cached per-read probe output (a JSON object with a "results" list) read by main().
# NOTE(review): absolute, user-specific paths — only valid on the original host.
SRC = Path("/home/ror25cal/MGnify/probes/results/probe_on_reads_test.json")
# Figure output path; the per-read .scores.jsonl and aggregate .summary.json
# files are written beside it by swapping this path's suffix.
OUT_PNG = Path("/home/ror25cal/MGnify/probes/results/probe_on_reads_score_distributions.png")


def stats(scores, labels):
    """Summarise how well ``scores`` separate the binary ``labels``.

    Returns a tuple ``(roc_auc, best_threshold, best_f1)`` of Python floats, where
    ``best_threshold`` is the precision-recall-curve threshold at which F1 peaks.
    """
    auc_value = roc_auc_score(labels, scores)
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # Floor the denominator so points with precision == recall == 0 give F1 = 0
    # instead of dividing by zero.
    f1_curve = 2 * precision * recall / np.maximum(precision + recall, 1e-9)
    peak = int(np.argmax(f1_curve))
    # precision/recall have one more entry than thresholds (the final
    # recall == 0 point has no threshold), so clamp the lookup index.
    threshold_at_peak = thresholds[min(peak, len(thresholds) - 1)]
    return float(auc_value), float(threshold_at_peak), float(f1_curve[peak])


def main():
    """Load cached per-read probe scores, plot them, and write score/summary files.

    Reads ``SRC``, a JSON object whose ``"results"`` list holds one record per
    read with at least ``read_id``, ``label``, ``seq_len``, ``max_logit`` and
    ``mean_logit`` (``median_logit`` is optional). It prints ROC-AUC and best-F1
    for the max- and mean-pooled logits, then writes three artefacts:

    * ``OUT_PNG`` — side-by-side score histograms, one panel per pooling mode;
    * ``OUT_PNG`` with suffix ``.scores.jsonl`` — one JSON line of raw scores per read;
    * ``OUT_PNG`` with suffix ``.summary.json`` — class counts plus AUC / threshold / F1.
    """
    records = json.loads(SRC.read_text(encoding="utf-8"))["results"]
    labels = np.array([r["label"] for r in records])
    max_logits = np.array([r["max_logit"] for r in records])
    mean_logits = np.array([r["mean_logit"] for r in records])

    auc_max, t_max, f1_max = stats(max_logits, labels)
    auc_mean, t_mean, f1_mean = stats(mean_logits, labels)
    print(f"max-pool:  AUC={auc_max:.4f}  best-F1 thr={t_max:.3f}  F1={f1_max:.3f}")
    print(f"mean-pool: AUC={auc_mean:.4f}  best-F1 thr={t_mean:.3f}  F1={f1_mean:.3f}")

    # Class counts are computed once and reused by both the figure title and the
    # summary file, so the two can never disagree with the data.
    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())

    panels = [
        ("max-pool", max_logits, auc_max, t_max, f1_max),
        ("mean-pool", mean_logits, auc_mean, t_mean, f1_mean),
    ]
    _plot_distributions(labels, panels, n_pos, n_neg)
    _write_scores(records)
    _write_summary({
        "n_pos": n_pos,
        "n_neg": n_neg,
        "max_pool":  {"auc": auc_max,  "best_f1_threshold": t_max,  "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    })


def _plot_distributions(labels, panels, n_pos, n_neg):
    """Draw one class-split score histogram per (name, scores, auc, thr, f1) panel and save to OUT_PNG."""
    fig, axes = plt.subplots(1, len(panels), figsize=(14, 5), squeeze=False)
    try:
        for ax, (name, scores, auc, best_thr, f1) in zip(axes[0], panels):
            pos = scores[labels == 1]
            neg = scores[labels == 0]
            bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 30)
            ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True)
            ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True)
            # Rug marks go just below the histogram baseline: negatives 2 % of the
            # y-span below it, positives 3 %, so the two rows stay separate no
            # matter where the y-axis starts. np.full (not full_like) keeps the
            # y-values float even if the scores happen to be integers.
            y_lo, y_hi = ax.get_ylim()
            span = y_hi - y_lo
            ax.scatter(neg, np.full(len(neg), y_lo - 0.02 * span), marker="|", s=80, color="tab:red", alpha=0.6)
            ax.scatter(pos, np.full(len(pos), y_lo - 0.03 * span), marker="|", s=80, color="tab:green", alpha=0.6)
            ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
            ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
            ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10)
            ax.set_ylabel("density", fontsize=10)
            ax.set_title(f"{name}  •  AUC={auc:.4f}  •  best-F1={f1:.3f}", fontsize=11)
            ax.legend(fontsize=8, loc="upper left")
            ax.grid(True, alpha=0.2)

        # Read counts come from the data rather than being hard-coded, so the
        # title stays correct if the cached reads change.
        fig.suptitle(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            f"{n_pos} pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref)  •  "
            f"{n_neg} neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. Pooling over all ~301 tokens of each read.",
            fontsize=10,
        )
        fig.tight_layout()
        OUT_PNG.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even if layout or saving fails.
        plt.close(fig)
    print(f"saved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}")


def _write_scores(records):
    """Write one JSON line of raw per-read scores beside OUT_PNG (suffix .scores.jsonl)."""
    out_jsonl = OUT_PNG.with_suffix(".scores.jsonl")
    with open(out_jsonl, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps({
                "read_id": r["read_id"],
                "label": r["label"],
                "seq_len": r["seq_len"],
                "max_logit": r["max_logit"],
                "mean_logit": r["mean_logit"],
                # median_logit is optional in the cached input; written as null if absent.
                "median_logit": r.get("median_logit"),
            }) + "\n"
            for r in records
        )
    print(f"saved {len(records)} per-read scores to {out_jsonl}")


def _write_summary(summary):
    """Write the aggregate ``summary`` dict as indented JSON beside OUT_PNG (suffix .summary.json)."""
    out_summary = OUT_PNG.with_suffix(".summary.json")
    out_summary.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"saved aggregate summary to {out_summary}")


# Run the plot/export pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()