Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Per-class threshold optimization for DeepAMR. | |
| Finds the optimal F1 threshold per drug class on the validation/test set | |
| instead of using a fixed 0.5 threshold. | |
| Usage: | |
| python -m src.ml.optimize_thresholds | |
| """ | |
| import json | |
| import logging | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from sklearn.metrics import f1_score | |
| from src.ml.inference import DeepAMRPredictor | |
# Configure root logging at INFO so the messages this module logs at INFO
# level are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Directory three levels above this file. Given the documented invocation
# (`python -m src.ml.optimize_thresholds`) this is presumably the repository
# root that contains `src/` -- confirm if the file is ever moved.
PROJECT_ROOT = Path(__file__).parent.parent.parent
def find_optimal_thresholds(
    y_true: np.ndarray,
    y_probs: np.ndarray,
    drug_classes: list,
    search_range: tuple = (0.1, 0.9),
    steps: int = 81,
) -> dict:
    """Grid-search the decision threshold that maximizes F1 for each class.

    Column ``i`` of ``y_true`` and ``y_probs`` is paired with
    ``drug_classes[i]``. Candidate thresholds are ``steps`` evenly spaced
    values from ``search_range[0]`` to ``search_range[1]`` (both included).
    A sample counts as positive when its probability is strictly greater
    than the candidate threshold.

    Args:
        y_true: 2-D array of ground-truth labels, one column per class.
        y_probs: 2-D array of predicted probabilities, same column layout.
        drug_classes: Class names, in column order.
        search_range: ``(low, high)`` bounds of the threshold grid.
        steps: Number of grid points.

    Returns:
        A dict mapping each drug name to
        ``{"threshold": <rounded to 3 dp>, "f1": <rounded to 4 dp>}``.
        When several candidates tie for the best F1 the lowest one is kept;
        when no candidate scores above zero the entry falls back to a
        threshold of 0.5 with an F1 of 0.0.
    """
    grid = np.linspace(search_range[0], search_range[1], steps)

    def _f1_at(column, cutoff):
        # Binarize one class column at `cutoff` and score it against truth.
        hard_labels = (y_probs[:, column] > cutoff).astype(int)
        return f1_score(y_true[:, column], hard_labels, zero_division=0)

    results = {}
    for column, drug in enumerate(drug_classes):
        scores = [_f1_at(column, cutoff) for cutoff in grid]
        # max() and list.index() both resolve ties to the first occurrence,
        # i.e. the lowest threshold that reaches the best score.
        top_score = max(scores, default=0.0)
        if top_score > 0.0:
            best_t = float(grid[scores.index(top_score)])
            best_f1 = top_score
        else:
            # Nothing beat zero: keep the conventional 0.5 cut-off.
            best_t, best_f1 = 0.5, 0.0
        results[drug] = {"threshold": round(best_t, 3), "f1": round(best_f1, 4)}
        logger.info(f"{drug}: threshold={best_t:.3f}, F1={best_f1:.4f}")
    return results
def main():
    """Tune per-class thresholds on the NCBI test split and save them as JSON.

    Loads the test features/labels from ``data/processed/ncbi`` under
    ``PROJECT_ROOT``, runs the ``DeepAMRPredictor`` model over all samples to
    obtain sigmoid probabilities, searches the best F1 threshold per drug
    class, writes the result to ``models/optimal_thresholds.json`` and logs
    the average per-class F1.

    NOTE(review): thresholds are selected and scored on the same test split,
    so the logged F1 is an optimistic estimate. If a separate validation
    split exists, tuning on it would give an unbiased number.
    """
    # Load test data
    data_dir = PROJECT_ROOT / "data" / "processed" / "ncbi"
    X_test = np.load(data_dir / "ncbi_amr_X_test.npy")
    y_test = np.load(data_dir / "ncbi_amr_y_test.npy")

    # Load model and get probabilities
    predictor = DeepAMRPredictor()

    # Apply the predictor's scaler, when it has one, before inference.
    features = X_test
    if predictor.scaler is not None:
        features = predictor.scaler.transform(features)
    X_tensor = torch.FloatTensor(features).to(predictor.device)

    # torch.no_grad() only turns off gradient tracking; it does not disable
    # dropout or batch-norm statistic updates. Switch to eval mode explicitly
    # so the probabilities are deterministic. This is a no-op if the
    # predictor already put the model in eval mode.
    if isinstance(predictor.model, torch.nn.Module):
        predictor.model.eval()
    with torch.no_grad():
        logits = predictor.model(X_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()

    # Find optimal thresholds
    optimal = find_optimal_thresholds(y_test, probs, predictor.drug_classes)

    # Save (create the target directory first so a missing folder does not
    # discard the work already done above).
    output_path = PROJECT_ROOT / "models" / "optimal_thresholds.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(optimal, f, indent=2)
    logger.info(f"Saved optimal thresholds to {output_path}")

    # Print summary; skip it when there are no classes, since the mean of an
    # empty list is NaN and triggers a NumPy warning.
    if optimal:
        avg_f1 = np.mean([v["f1"] for v in optimal.values()])
        logger.info(f"Average per-class F1 with optimized thresholds: {avg_f1:.4f}")
# Run the threshold optimization only when this file is executed as a script
# (e.g. `python -m src.ml.optimize_thresholds`), not when it is imported.
if __name__ == "__main__":
    main()