Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Production-Ready Inference Module for DeepAMR. | |
| This module provides a clean API for making predictions with trained models, | |
| designed for integration with web frontends and APIs. | |
| Usage: | |
| from src.ml.inference import DeepAMRPredictor | |
| predictor = DeepAMRPredictor() | |
| results = predictor.predict(features) | |
| """ | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import joblib | |
# NOTE(review): calling basicConfig at import time configures the root logger
# for the whole process; for a library module, consider leaving configuration
# to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| # ============================================================================= | |
| # Neural Network Architectures (must match training) | |
| # ============================================================================= | |
class MultiHeadAttention(nn.Module):
    """Multi-head attention used as a learned feature re-weighting step.

    The input is a flat feature vector; ``forward`` treats it as a
    length-1 sequence, runs standard scaled dot-product attention across
    ``n_heads`` heads, and returns a vector of the same shape.
    """

    def __init__(self, d_model: int, n_heads: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Attribute names must stay stable: they are the state_dict keys.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Promote (batch, d_model) to a (batch, seq=1, d_model) sequence.
        seq = x.unsqueeze(1)
        n_batch = seq.size(0)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, d_model) -> (batch, heads, seq, d_head)
            return t.view(n_batch, -1, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.q_linear(seq))
        k = split_heads(self.k_linear(seq))
        v = split_heads(self.v_linear(seq))

        scores = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        weights = F.softmax(scores, dim=-1)

        # Merge heads back into a single (batch, seq, d_model) tensor.
        mixed = (weights @ v).transpose(1, 2).contiguous()
        mixed = mixed.view(n_batch, -1, self.n_heads * self.d_head)
        return self.out_linear(mixed).squeeze(1)
class ResidualBlock(nn.Module):
    """Pre-activation-style residual block: ``gelu(x + net(x))``."""

    def __init__(self, dim: int, dropout: float = 0.2):
        super().__init__()
        # Layer order is fixed: indices become state_dict keys (net.0, net.1, ...).
        layers = [
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        ]
        self.net = nn.Sequential(*layers)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.net(x)
        return self.gelu(x + residual)
class AdvancedDeepAMR(nn.Module):
    """Deep multi-label classifier for AMR prediction.

    Pipeline: linear embedding -> residual attention re-weighting ->
    ``n_blocks`` residual blocks -> two-layer classification head emitting
    one logit per drug class.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 512,
        n_blocks: int = 4,
    ):
        super().__init__()
        # Project raw k-mer features into the hidden space.
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )
        self.attention = MultiHeadAttention(hidden_dim, n_heads=8)
        self.res_blocks = nn.ModuleList(
            ResidualBlock(hidden_dim) for _ in range(n_blocks)
        )
        # Head: hidden -> hidden/2 -> one logit per class.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.embedding(x)
        # Residual attention: re-weight features without losing the original.
        hidden = hidden + self.attention(hidden)
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        return self.classifier(hidden)
| # ============================================================================= | |
| # Production Predictor Class | |
| # ============================================================================= | |
| MODEL_VERSION = "1.0.0" | |
class DeepAMRPredictor:
    """Production-ready predictor for AMR resistance.

    This class provides a clean interface for making predictions with
    trained DeepAMR models, suitable for frontend/API integration.

    Attributes:
        model: The loaded PyTorch model (kept in eval mode).
        scaler: Feature scaler for preprocessing, or None.
        drug_classes: List of drug class names, one per output logit.
        device: Computation device (cuda/mps/cpu).
        optimal_thresholds: Optional per-class decision thresholds.
        performance_metrics: Optional metrics recorded at training time.

    Example:
        >>> predictor = DeepAMRPredictor()
        >>> results = predictor.predict(kmer_features)
        >>> print(results['predictions'])
        {'aminoglycoside': True, 'beta-lactam': False, ...}
    """

    # Fallback label set, used when the checkpoint does not carry its own.
    DEFAULT_DRUG_CLASSES = [
        "aminoglycoside",
        "beta-lactam",
        "fosfomycin",
        "glycopeptide",
        "macrolide",
        "phenicol",
        "quinolone",
        "rifampicin",
        "sulfonamide",
        "tetracycline",
        "trimethoprim",
    ]

    def __init__(
        self,
        model_path: str = "models/advanced_deepamr_system.pt",
        device: str = "auto",
    ):
        """Initialize the predictor.

        Args:
            model_path: Path to the trained model checkpoint.
            device: Device to use ('cuda', 'mps', 'cpu', or 'auto').

        Raises:
            FileNotFoundError: If the checkpoint does not exist.
        """
        self.model_path = Path(model_path)
        self.device = self._resolve_device(device)
        logger.info(f"Using device: {self.device}")

        self._load_model()

        # Optional per-class decision thresholds tuned on validation data.
        # NOTE(review): these side-car paths are hard-coded and not derived
        # from model_path — confirm they match the deployed layout.
        self.optimal_thresholds: Optional[Dict] = None
        thresholds_path = Path("models/optimal_thresholds.json")
        if thresholds_path.exists():
            with open(thresholds_path) as f:
                self.optimal_thresholds = json.load(f)
            logger.info("Loaded per-class optimal thresholds")

        # Optional training/evaluation metrics, surfaced via model_info().
        self.performance_metrics: Optional[Dict] = None
        results_path = Path("models/advanced_system_results.json")
        if results_path.exists():
            with open(results_path) as f:
                self.performance_metrics = json.load(f)
            logger.info("Loaded performance metrics")

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Map 'auto' to the best available backend (cuda > mps > cpu)."""
        if device != "auto":
            return torch.device(device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _load_model(self):
        """Load the trained model and preprocessing components.

        The architecture is rebuilt from dimensions inferred from the
        checkpoint's state dict, so this code serves models trained with
        different feature counts or hidden sizes.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")
        # SECURITY: weights_only=False unpickles arbitrary objects (needed
        # here to restore the bundled scaler) — only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)

        self.scaler = checkpoint.get("scaler")
        self.drug_classes = checkpoint.get("classes", self.DEFAULT_DRUG_CLASSES)

        # Infer architecture dimensions from the stored weights; supports
        # checkpoints that are either a dict with "model_state_dict" or a
        # bare state dict.
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        input_dim = state_dict["embedding.0.weight"].shape[1]
        output_dim = state_dict["classifier.3.weight"].shape[0]
        hidden_dim = state_dict["embedding.0.weight"].shape[0]

        self.model = AdvancedDeepAMR(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
        ).to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        logger.info(f"Model loaded successfully. Drug classes: {len(self.drug_classes)}")

    def _threshold_vector(self, threshold: float) -> np.ndarray:
        """Return the per-class decision cutoffs for one predict() call.

        The tuned per-class thresholds only apply when the caller kept the
        default of 0.5; any explicit non-default threshold applies
        uniformly to every class.
        """
        if threshold == 0.5 and self.optimal_thresholds:
            return np.array([
                self.optimal_thresholds.get(drug, {}).get("threshold", 0.5)
                for drug in self.drug_classes
            ])
        return np.full(len(self.drug_classes), threshold)

    def predict(
        self,
        features: Union[np.ndarray, List[List[float]]],
        threshold: float = 0.5,
        return_probabilities: bool = True,
    ) -> Union[Dict, List[Dict]]:
        """Make AMR resistance predictions.

        Args:
            features: Input features (k-mer frequencies). Shape
                (n_samples, n_features), or (n_features,) for one sample.
            threshold: Probability threshold for positive prediction
                (default: 0.5). Leaving the default enables the tuned
                per-class thresholds when they were loaded.
            return_probabilities: Whether to include probability scores.

        Returns:
            A single dict for one input sample, or a list of dicts for a
            batch. Each dict contains:
                - predictions: Dict mapping drug class to resistance status
                - probabilities: Dict mapping drug class to probability
                  (if requested)
                - resistant_count: Number of classes predicted resistant
                - susceptible_count: Number of classes predicted susceptible
        """
        # Accept lists/tuples as well as arrays; no copy for ndarrays.
        features = np.asarray(features)
        if features.ndim == 1:
            features = features.reshape(1, -1)

        if self.scaler is not None:
            features = self.scaler.transform(features)

        X = torch.as_tensor(features, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            probabilities = torch.sigmoid(self.model(X)).cpu().numpy()

        # Cutoffs are identical for every sample; compute them once.
        cutoffs = self._threshold_vector(threshold)

        results = []
        for probs in probabilities:
            preds = (probs > cutoffs).astype(int)
            result = {
                "predictions": {
                    drug: bool(preds[j]) for j, drug in enumerate(self.drug_classes)
                },
                "resistant_count": int(preds.sum()),
                "susceptible_count": int(len(self.drug_classes) - preds.sum()),
            }
            if return_probabilities:
                result["probabilities"] = {
                    drug: float(probs[j]) for j, drug in enumerate(self.drug_classes)
                }
            results.append(result)

        # Single-sample calls get an unwrapped dict for API convenience.
        return results[0] if len(results) == 1 else results

    def predict_batch(
        self,
        features_list: List[np.ndarray],
        threshold: float = 0.5,
        batch_size: int = 32,
    ) -> List[Dict]:
        """Make predictions on a batch of samples efficiently.

        Args:
            features_list: List of per-sample feature arrays (equal length).
            threshold: Probability threshold for positive prediction.
            batch_size: Number of samples per forward pass.

        Returns:
            List of prediction dictionaries, one per input sample.
        """
        all_results: List[Dict] = []
        for i in range(0, len(features_list), batch_size):
            batch = np.array(features_list[i:i + batch_size])
            batch_results = self.predict(batch, threshold=threshold)
            # predict() unwraps single-sample batches; re-wrap to keep the
            # returned list flat and one-per-sample.
            if isinstance(batch_results, dict):
                all_results.append(batch_results)
            else:
                all_results.extend(batch_results)
        return all_results

    def get_resistance_summary(self, predictions: Dict) -> Dict:
        """Generate a human-readable summary of predictions.

        Args:
            predictions: Output from predict() for a single sample.

        Returns:
            Summary dictionary with risk level, resistant/susceptible drug
            lists, and (when probabilities are present) confidence
            statistics. All values are plain Python types, safe for JSON
            serialization.
        """
        resistant_drugs = [
            drug for drug, status in predictions["predictions"].items() if status
        ]
        susceptible_drugs = [
            drug for drug, status in predictions["predictions"].items() if not status
        ]

        # Risk bands by resistant-class count: >=5 HIGH, >=3 MODERATE,
        # >=1 LOW, otherwise MINIMAL.
        if predictions["resistant_count"] >= 5:
            risk_level = "HIGH"
            risk_description = "Multi-drug resistant (MDR) - Requires specialist consultation"
        elif predictions["resistant_count"] >= 3:
            risk_level = "MODERATE"
            risk_description = "Multiple resistance detected - Consider alternative treatments"
        elif predictions["resistant_count"] >= 1:
            risk_level = "LOW"
            risk_description = "Limited resistance - Standard alternatives available"
        else:
            risk_level = "MINIMAL"
            risk_description = "No predicted resistance - Standard treatment likely effective"

        summary = {
            "risk_level": risk_level,
            "risk_description": risk_description,
            "resistant_drugs": resistant_drugs,
            "susceptible_drugs": susceptible_drugs,
            "total_tested": len(self.drug_classes),
        }

        if "probabilities" in predictions:
            probs = predictions["probabilities"]
            # "High confidence" = probability far from the 0.5 decision zone.
            high_confidence = [
                drug for drug, prob in probs.items()
                if prob > 0.8 or prob < 0.2
            ]
            summary["high_confidence_predictions"] = len(high_confidence)
            # FIX: cast np.float64 to a builtin float so the summary stays
            # plain-Python (type-stable) for downstream serializers.
            summary["average_confidence"] = float(np.mean([
                max(p, 1 - p) for p in probs.values()
            ]))

        return summary

    def model_info(self) -> Dict:
        """Get information about the loaded model for display/monitoring."""
        info = {
            "model_path": str(self.model_path),
            "device": str(self.device),
            "drug_classes": self.drug_classes,
            "n_classes": len(self.drug_classes),
            "has_scaler": self.scaler is not None,
            "model_version": MODEL_VERSION,
            "has_optimal_thresholds": self.optimal_thresholds is not None,
        }
        if self.performance_metrics:
            info["performance"] = self.performance_metrics
        return info
| # ============================================================================= | |
| # Sklearn Model Predictor (for ensemble/traditional ML models) | |
| # ============================================================================= | |
class SklearnAMRPredictor:
    """Predictor for sklearn-based AMR models loaded from joblib files."""

    def __init__(
        self,
        model_path: str = "models/optimized/optimized_ensemble_stacking.joblib",
    ):
        """Initialize sklearn predictor.

        Args:
            model_path: Path to joblib model file.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        self.model_path = Path(model_path)
        self._load_model()

    def _load_model(self):
        """Load the sklearn model and preprocessing components."""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        # SECURITY: joblib.load unpickles arbitrary objects — only load
        # model files from trusted sources.
        data = joblib.load(self.model_path)
        self.model = data["model"]
        self.scaler = data.get("scaler")
        self.feature_selector = data.get("feature_selector")
        self.drug_classes = data.get("classes", DeepAMRPredictor.DEFAULT_DRUG_CLASSES)
        logger.info(f"Sklearn model loaded from {self.model_path}")

    def predict(
        self,
        features: np.ndarray,
        return_probabilities: bool = True,
    ) -> Dict:
        """Make predictions with the sklearn model.

        Args:
            features: Feature array of shape (n_samples, n_features), or
                (n_features,) for a single sample.
            return_probabilities: Attach per-class probabilities when the
                model exposes ``predict_proba``.

        Returns:
            One prediction dict for a single sample, or a list of dicts
            for a batch (same shape as DeepAMRPredictor.predict()).
        """
        # Accept lists as well as arrays; no copy for ndarrays.
        features = np.asarray(features)
        if features.ndim == 1:
            features = features.reshape(1, -1)

        # Apply the same preprocessing used at training time.
        if self.scaler is not None:
            features = self.scaler.transform(features)
        if self.feature_selector is not None:
            # A raw index array selects columns directly; otherwise assume
            # a fitted sklearn selector with a transform() method.
            if isinstance(self.feature_selector, np.ndarray):
                features = features[:, self.feature_selector]
            else:
                features = self.feature_selector.transform(features)

        predictions = self.model.predict(features)

        results = []
        for i in range(len(predictions)):
            preds = predictions[i]
            result = {
                "predictions": {
                    drug: bool(preds[j]) for j, drug in enumerate(self.drug_classes)
                },
                "resistant_count": int(preds.sum()),
                "susceptible_count": int(len(self.drug_classes) - preds.sum()),
            }
            if return_probabilities and hasattr(self.model, "predict_proba"):
                # NOTE(review): multilabel wrappers (e.g. MultiOutputClassifier)
                # return a *list* of per-class arrays from predict_proba, in
                # which case probs[j] below would not index drug classes —
                # verify against the actual saved model type.
                try:
                    probs = self.model.predict_proba(features[i:i+1])[0]
                    result["probabilities"] = {
                        drug: float(probs[j]) for j, drug in enumerate(self.drug_classes)
                    }
                except Exception:
                    # FIX: was a silent `pass`; log so shape mismatches are
                    # diagnosable instead of probabilities quietly missing.
                    logger.debug(
                        "predict_proba failed; omitting probabilities",
                        exc_info=True,
                    )
            results.append(result)

        return results[0] if len(results) == 1 else results

    def model_info(self) -> Dict:
        """Get information about the loaded sklearn model."""
        return {
            "model_path": str(self.model_path),
            "device": "cpu",
            "drug_classes": self.drug_classes,
            "n_classes": len(self.drug_classes),
            "has_scaler": self.scaler is not None,
        }
| # ============================================================================= | |
| # Unified Predictor Factory | |
| # ============================================================================= | |
def get_predictor(
    model_type: str = "deep_learning",
    model_path: Optional[str] = None,
) -> Union[DeepAMRPredictor, SklearnAMRPredictor]:
    """Factory function to get the appropriate predictor.

    Args:
        model_type: Type of model ('deep_learning', 'sklearn', 'ensemble')
        model_path: Optional custom model path

    Returns:
        Configured predictor instance

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type == "deep_learning":
        return DeepAMRPredictor(model_path or "models/advanced_deepamr_system.pt")
    if model_type in ("sklearn", "ensemble"):
        return SklearnAMRPredictor(
            model_path or "models/optimized/optimized_ensemble_stacking.joblib"
        )
    raise ValueError(f"Unknown model type: {model_type}")
| # ============================================================================= | |
| # CLI for testing | |
| # ============================================================================= | |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DeepAMR Inference")
    parser.add_argument("--model", default="deep_learning", choices=["deep_learning", "sklearn"])
    parser.add_argument("--test", action="store_true", help="Run test prediction")
    args = parser.parse_args()

    predictor = get_predictor(args.model)
    # FIX: model_info is a method — the original printed the bound-method
    # repr instead of the info dict.
    print(f"Model info: {predictor.model_info()}")

    if args.test:
        # np is already imported at module level; the original redundantly
        # re-imported numpy here.
        X_test = np.load("data/processed/ncbi/ncbi_amr_X_test.npy")

        # Predict on the first sample as a smoke test.
        result = predictor.predict(X_test[0])
        print("\nTest prediction:")
        print(f"Predictions: {result['predictions']}")
        print(f"Resistant: {result['resistant_count']}, Susceptible: {result['susceptible_count']}")

        if "probabilities" in result:
            print("\nProbabilities:")
            for drug, prob in result["probabilities"].items():
                status = "R" if prob > 0.5 else "S"
                print(f"  {drug}: {prob:.3f} ({status})")