File size: 6,365 Bytes
eb69de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Re-render the reads and transfer score-distribution plots using the
v1 AMR probe's TEST-tuned thresholds (from amr_binary_v1) — i.e., the
deployed thresholds — instead of newly-tuned thresholds on each plot's data.

Local-only: reads each plot's cached *.scores.jsonl, recomputes AUC and
precision/recall/F1 at the fixed thresholds, re-renders the PNG, and
overwrites the matching *.summary.json.

Usage:
  python probes/refit_thresholds.py
"""
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score


# Thresholds from amr_binary_v1_score_distributions.summary.json (test set):
# the v1 AMR probe's deployed decision boundaries on the TEST split of
# AMR-vs-MISC. They are applied unchanged to every dataset re-plotted here,
# AMR_MAX_THR to the max-pooled logits and AMR_MEAN_THR to the mean-pooled
# logits (a sample is positive when its logit is strictly greater).
# NOTE(review): hard-coded copies — keep in sync if the v1 summary is regenerated.
AMR_MAX_THR = 6.583248138427734
AMR_MEAN_THR = -0.6747861504554749


def f1_at_threshold(scores, labels, thr):
    """Return ``(precision, recall, f1)`` for the decision rule ``score > thr``.

    Args:
        scores: 1-D array-like of real-valued scores (e.g. probe logits).
            Lists and NumPy arrays are both accepted.
        labels: 1-D array-like of binary labels aligned with *scores*;
            ``1`` marks a positive and ``0`` a negative.
        thr: Decision threshold. A sample is predicted positive only when its
            score is strictly greater than *thr*.

    Returns:
        A tuple of three Python floats. All three are ``0.0`` when the rule
        predicts no positives or when *labels* contains no positives, i.e.
        whenever precision or recall would be undefined.
    """
    # Coerce so plain Python lists work: on a list, ``>`` raises TypeError
    # and ``== 1`` collapses to a single bool instead of an element-wise mask.
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    pred_pos = scores > thr
    is_pos = labels == 1
    is_neg = labels == 0

    tp = int((pred_pos & is_pos).sum())
    fp = int((pred_pos & is_neg).sum())
    fn = int((~pred_pos & is_pos).sum())

    if tp + fp == 0 or tp + fn == 0:
        return 0.0, 0.0, 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # tp can still be 0 here (with fp > 0 and fn > 0), making both terms 0;
    # the floor keeps the division defined and yields F1 = 0.
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    return float(precision), float(recall), float(f1)


def _load_scores(scores_jsonl: Path):
    """Load cached per-example scores as ``(labels, max_logits, mean_logits)``.

    Each non-blank line of *scores_jsonl* must be a JSON object with
    ``label``, ``max_logit`` and ``mean_logit`` keys. Logits are coerced to
    float so an all-integer file is handled exactly like any other.

    Raises:
        ValueError: if the file has no rows, or lacks either the ``label == 1``
            or the ``label == 0`` class (AUC and per-class histograms are
            undefined in that case).
    """
    text = scores_jsonl.read_text(encoding="utf-8")
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    if not rows:
        raise ValueError(f"{scores_jsonl}: no score rows found")

    labels = np.array([row["label"] for row in rows])
    max_logits = np.array([row["max_logit"] for row in rows], dtype=float)
    mean_logits = np.array([row["mean_logit"] for row in rows], dtype=float)

    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError(
            f"{scores_jsonl}: need both classes (label 1 and label 0), "
            f"got n_pos={n_pos}, n_neg={n_neg}"
        )
    return labels, max_logits, mean_logits


def replot(scores_jsonl: Path, png_path: Path, summary_path: Path,
           pos_label: str, neg_label: str, suptitle: str,
           x_axis_label: str = "per-region probe logit (= w · h_pooled + b)"):
    """Re-render one score-distribution figure and summary at the fixed v1 thresholds.

    Computes ROC AUC plus precision/recall/F1 at ``AMR_MAX_THR`` (max-pool) and
    ``AMR_MEAN_THR`` (mean-pool), prints them, draws a two-panel histogram
    figure and writes a JSON summary.

    Args:
        scores_jsonl: Cached scores, one JSON object per line with ``label``
            (1 = positive, 0 = negative), ``max_logit`` and ``mean_logit``.
        png_path: Destination of the figure; parent directories are created.
        summary_path: Destination of the JSON summary; parent directories are
            created and any existing file is overwritten, not merged.
        pos_label: Legend name for the ``label == 1`` class.
        neg_label: Legend name for the ``label == 0`` class.
        suptitle: Figure-level title.
        x_axis_label: x-axis label used on both panels.

    Raises:
        ValueError: if *scores_jsonl* has no rows or lacks one of the classes.
    """
    labels, max_logits, mean_logits = _load_scores(scores_jsonl)

    auc_max = float(roc_auc_score(labels, max_logits))
    auc_mean = float(roc_auc_score(labels, mean_logits))

    p_max, r_max, f1_max = f1_at_threshold(max_logits, labels, AMR_MAX_THR)
    p_mean, r_mean, f1_mean = f1_at_threshold(mean_logits, labels, AMR_MEAN_THR)

    print(f"  max-pool:  AUC={auc_max:.4f}  fixed thr={AMR_MAX_THR:.3f}  "
          f"P={p_max:.3f} R={r_max:.3f} F1={f1_max:.3f}")
    print(f"  mean-pool: AUC={auc_mean:.4f}  fixed thr={AMR_MEAN_THR:.3f}  "
          f"P={p_mean:.3f} R={r_mean:.3f} F1={f1_mean:.3f}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = [
        (axes[0], max_logits,  "max-pool",  auc_max,  AMR_MAX_THR,  p_max,  r_max,  f1_max),
        (axes[1], mean_logits, "mean-pool", auc_mean, AMR_MEAN_THR, p_mean, r_mean, f1_mean),
    ]
    for ax, scores, name, auc, fixed_thr, p, r, f1 in panels:
        pos = scores[labels == 1]
        neg = scores[labels == 0]
        # Shared bin edges so the two class histograms are directly comparable.
        bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 50)
        ax.hist(neg, bins=bins, alpha=0.55, label=f"{neg_label} (n={len(neg)})",
                color="tab:red", density=True)
        ax.hist(pos, bins=bins, alpha=0.55, label=f"{pos_label} (n={len(pos)})",
                color="tab:green", density=True)
        # Rug marks sit just below the histogram's current y-range. Assuming the
        # lower y-limit is 0 (typical for histograms), rug_y is negative, so the
        # x1.5 row for positives lands further below the negatives' row.
        y_lo, y_hi = ax.get_ylim()
        rug_y = y_lo - 0.02 * (y_hi - y_lo)
        ax.scatter(neg, np.full(neg.shape, rug_y), marker="|", s=60,
                   color="tab:red", alpha=0.5)
        ax.scatter(pos, np.full(pos.shape, rug_y * 1.5), marker="|", s=60,
                   color="tab:green", alpha=0.5)
        ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
        ax.axvline(fixed_thr, color="black", lw=1, ls=":",
                   label=f"AMR-vs-MISC test boundary ({fixed_thr:.2f})")
        ax.set_xlabel(x_axis_label, fontsize=10)
        ax.set_ylabel("density", fontsize=10)
        ax.set_title(f"{name}  •  AUC={auc:.4f}  •  F1@fixed_thr={f1:.3f}  "
                     f"(P={p:.2f}, R={r:.2f})", fontsize=10)
        ax.legend(fontsize=8, loc="upper left")
        ax.grid(True, alpha=0.2)
    fig.suptitle(suptitle, fontsize=10)
    fig.tight_layout()
    png_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(png_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"  saved {png_path.stat().st_size/1024:.1f} KB to {png_path}")

    summary = {
        "fixed_thresholds_source": "amr_binary_v1_score_distributions.summary.json (TEST split of AMR-vs-MISC, the v1 probe's deployed thresholds)",
        "n_pos": int((labels == 1).sum()),
        "n_neg": int((labels == 0).sum()),
        "max_pool":  {"auc": auc_max,  "fixed_threshold": AMR_MAX_THR,
                      "precision": p_max,  "recall": r_max,  "f1": f1_max},
        "mean_pool": {"auc": auc_mean, "fixed_threshold": AMR_MEAN_THR,
                      "precision": p_mean, "recall": r_mean, "f1": f1_mean},
    }
    # The summary may live in a different directory than the PNG.
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"  saved summary to {summary_path}")


def main(results_dir: Path | None = None):
    """Re-render the reads and transfer score-distribution plots.

    Args:
        results_dir: Directory that holds each plot's cached
            ``<stem>.scores.jsonl`` and receives ``<stem>.png`` and
            ``<stem>.summary.json``. Defaults to the original hard-coded
            ``/home/ror25cal/MGnify/probes/results`` so existing invocations
            behave exactly as before; pass another path to run elsewhere.
    """
    if results_dir is None:
        results_dir = Path("/home/ror25cal/MGnify/probes/results")
    results_dir = Path(results_dir)

    def artefacts(stem: str) -> dict[str, Path]:
        # Each plot's input and outputs share one filename stem.
        return {
            "scores_jsonl": results_dir / f"{stem}.scores.jsonl",
            "png_path": results_dir / f"{stem}.png",
            "summary_path": results_dir / f"{stem}.summary.json",
        }

    # Reads plot
    print("=== Reads plot ===")
    replot(
        **artefacts("probe_on_reads_score_distributions"),
        pos_label="AMR positive reads",
        neg_label="matched-neg reads",
        suptitle=(
            "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
            "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref)  •  "
            "50 neg reads from matched MGYG000307615_00395\n"
            "MAG was in v1 val split — never trained on. F1 at AMR-vs-MISC test threshold (deployed)."
        ),
        x_axis_label="per-read probe logit (= w · h_pooled + b)",
    )

    # Transfer plot
    print("\n=== Transfer plot (AMR vs VFDB virulence) ===")
    replot(
        **artefacts("transfer_amr_vs_virulence_score_distributions"),
        pos_label="AMR positive (CDS-only)",
        neg_label="VFDB virulence (CDS-only)",
        suptitle=(
            "TRANSFER TEST: v1 AMR linear probe — AMR positives (336) vs VFDB virulence positives (336)\n"
            "Pooled over CDS-only tokens for both classes (apples-to-apples).\n"
            "F1 at AMR-vs-MISC test threshold (deployed) — not re-tuned on this transfer set."
        ),
    )


# Script entry point: re-render both plots only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()