File size: 2,524 Bytes
3255634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
"""Per-class threshold optimization for DeepAMR.

Finds the optimal F1 threshold per drug class on the validation/test set
instead of using a fixed 0.5 threshold.

Usage:
    python -m src.ml.optimize_thresholds
"""

import json
import logging
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import f1_score

from src.ml.inference import DeepAMRPredictor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

PROJECT_ROOT = Path(__file__).parent.parent.parent


def find_optimal_thresholds(
    y_true: np.ndarray,
    y_probs: np.ndarray,
    drug_classes: list,
    search_range: tuple = (0.1, 0.9),
    steps: int = 81,
) -> dict:
    """Grid-search the per-class decision threshold that maximizes F1.

    Args:
        y_true: Binary ground-truth labels, shape (n_samples, n_classes).
        y_probs: Predicted probabilities, shape (n_samples, n_classes),
            column order matching ``drug_classes``.
        drug_classes: Class (drug) names, one per label column.
        search_range: Inclusive (low, high) bounds of the threshold grid.
        steps: Number of evenly spaced thresholds to evaluate
            (81 over (0.1, 0.9) gives a 0.01 step).

    Returns:
        Mapping of drug name -> {"threshold": float, "f1": float}.
    """
    thresholds = np.linspace(search_range[0], search_range[1], steps)
    optimal = {}

    for i, drug in enumerate(drug_classes):
        best_f1 = 0.0
        # Fall back to the conventional 0.5 if no grid point beats F1 = 0.
        best_t = 0.5
        for t in thresholds:
            preds = (y_probs[:, i] > t).astype(int)
            # zero_division=0: score 0 (no warning) when a class has no
            # positive predictions at this threshold.
            f1 = f1_score(y_true[:, i], preds, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_t = float(t)
        optimal[drug] = {"threshold": round(best_t, 3), "f1": round(best_f1, 4)}
        # Lazy %-args: interpolation is skipped when INFO logging is disabled.
        logger.info("%s: threshold=%.3f, F1=%.4f", drug, best_t, best_f1)

    return optimal


def main():
    """Compute and persist per-class optimal thresholds on the test split.

    Loads the held-out NCBI test features/labels, runs the trained DeepAMR
    model to obtain per-class probabilities, grid-searches the F1-optimal
    threshold for each drug class, and writes the result to
    ``models/optimal_thresholds.json``.
    """
    # Load the held-out test split produced by the preprocessing pipeline.
    data_dir = PROJECT_ROOT / "data" / "processed" / "ncbi"
    X_test = np.load(data_dir / "ncbi_amr_X_test.npy")
    y_test = np.load(data_dir / "ncbi_amr_y_test.npy")

    # Load the trained model (and its fitted scaler, if one was saved).
    predictor = DeepAMRPredictor()

    # Apply the same feature scaling used at training time, when present.
    features = X_test
    if predictor.scaler is not None:
        features = predictor.scaler.transform(features)

    # Forward pass without gradient tracking; sigmoid turns logits into
    # independent per-class probabilities (multi-label setting).
    X_tensor = torch.FloatTensor(features).to(predictor.device)
    with torch.no_grad():
        logits = predictor.model(X_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()

    # Grid-search the F1-optimal threshold for each drug class.
    optimal = find_optimal_thresholds(y_test, probs, predictor.drug_classes)

    # Persist thresholds; create models/ first so json.dump cannot fail
    # with FileNotFoundError on a fresh checkout.
    output_path = PROJECT_ROOT / "models" / "optimal_thresholds.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(optimal, f, indent=2)
    logger.info("Saved optimal thresholds to %s", output_path)

    # Report the macro-average F1 achieved with the tuned thresholds.
    avg_f1 = np.mean([v["f1"] for v in optimal.values()])
    logger.info("Average per-class F1 with optimized thresholds: %.4f", avg_f1)


if __name__ == "__main__":
    main()