#!/usr/bin/env python3
"""Production-Ready Inference Module for DeepAMR.

This module provides a clean API for making predictions with trained models,
designed for integration with web frontends and APIs.

Usage:
    from src.ml.inference import DeepAMRPredictor

    predictor = DeepAMRPredictor()
    results = predictor.predict(features)
"""

import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# joblib is only needed by SklearnAMRPredictor; keep it optional so
# deep-learning-only deployments can run without it installed.
try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# =============================================================================
# Neural Network Architectures (must match training)
# =============================================================================

class MultiHeadAttention(nn.Module):
    """Multi-head attention for feature importance weighting.

    The input is treated as a length-1 sequence (unsqueezed to
    (batch, 1, d_model)), so attention acts as a learned feature
    reweighting rather than token-to-token attention.
    """

    def __init__(self, d_model: int, n_heads: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply scaled dot-product attention to a (batch, d_model) input."""
        x = x.unsqueeze(1)  # (batch, 1, d_model): single-token "sequence"
        batch_size = x.size(0)

        q = self.q_linear(x).view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v).transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.n_heads * self.d_head)
        return self.out_linear(context).squeeze(1)


class ResidualBlock(nn.Module):
    """Residual block with GELU activation."""

    def __init__(self, dim: int, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        )
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-activation residual connection: GELU(x + F(x)).
        return self.gelu(x + self.net(x))


class AdvancedDeepAMR(nn.Module):
    """Advanced Deep Learning Model for AMR Prediction.

    Architecture: linear embedding -> residual attention -> a stack of
    residual blocks -> two-layer classifier head. Attribute names must stay
    in sync with training so checkpoints load via ``load_state_dict``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 512,
        n_blocks: int = 4,
    ):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )
        self.attention = MultiHeadAttention(hidden_dim, n_heads=8)
        self.res_blocks = nn.ModuleList([
            ResidualBlock(hidden_dim) for _ in range(n_blocks)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits of shape (batch, output_dim)."""
        x = self.embedding(x)
        x = x + self.attention(x)  # residual attention
        for block in self.res_blocks:
            x = block(x)
        return self.classifier(x)


# =============================================================================
# Production Predictor Class
# =============================================================================

MODEL_VERSION = "1.0.0"


class DeepAMRPredictor:
    """Production-ready predictor for AMR resistance.

    This class provides a clean interface for making predictions with
    trained DeepAMR models, suitable for frontend/API integration.

    Attributes:
        model: The loaded PyTorch model
        scaler: Feature scaler for preprocessing
        drug_classes: List of drug class names
        device: Computation device (cuda/mps/cpu)

    Example:
        >>> predictor = DeepAMRPredictor()
        >>> results = predictor.predict(kmer_features)
        >>> print(results['predictions'])
        {'aminoglycoside': True, 'beta-lactam': False, ...}
    """

    # Default drug classes, used when the checkpoint does not carry its own.
    DEFAULT_DRUG_CLASSES = [
        "aminoglycoside",
        "beta-lactam",
        "fosfomycin",
        "glycopeptide",
        "macrolide",
        "phenicol",
        "quinolone",
        "rifampicin",
        "sulfonamide",
        "tetracycline",
        "trimethoprim",
    ]

    def __init__(
        self,
        model_path: str = "models/advanced_deepamr_system.pt",
        device: str = "auto",
    ):
        """Initialize the predictor.

        Args:
            model_path: Path to the trained model checkpoint
            device: Device to use ('cuda', 'mps', 'cpu', or 'auto')

        Raises:
            FileNotFoundError: If the checkpoint does not exist.
        """
        self.model_path = Path(model_path)

        # Resolve computation device; 'auto' prefers cuda, then mps, then cpu.
        if device == "auto":
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")

        # Load model and preprocessing components.
        self._load_model()

        # Load optimal per-class thresholds if available (see predict()).
        self.optimal_thresholds: Optional[Dict] = None
        thresholds_path = Path("models/optimal_thresholds.json")
        if thresholds_path.exists():
            with open(thresholds_path) as f:
                self.optimal_thresholds = json.load(f)
            logger.info("Loaded per-class optimal thresholds")

        # Load performance metadata if available (exposed via model_info).
        self.performance_metrics: Optional[Dict] = None
        results_path = Path("models/advanced_system_results.json")
        if results_path.exists():
            with open(results_path) as f:
                self.performance_metrics = json.load(f)
            logger.info("Loaded performance metrics")

    def _load_model(self):
        """Load the trained model and preprocessing components.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")
        # weights_only=False is required because the checkpoint bundles
        # pickled objects (e.g. the sklearn scaler). Pickle executes
        # arbitrary code on load -- only load checkpoints you trust.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)

        # Extract components (checkpoint may be a plain state dict).
        self.scaler = checkpoint.get("scaler")
        self.drug_classes = checkpoint.get("classes", self.DEFAULT_DRUG_CLASSES)
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Infer architecture dimensions from the weight shapes so they never
        # need to be hard-coded here.
        input_dim = state_dict["embedding.0.weight"].shape[1]
        output_dim = state_dict["classifier.3.weight"].shape[0]
        hidden_dim = state_dict["embedding.0.weight"].shape[0]

        self.model = AdvancedDeepAMR(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
        ).to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        logger.info(f"Model loaded successfully. Drug classes: {len(self.drug_classes)}")

    def predict(
        self,
        features: Union[np.ndarray, List[List[float]]],
        threshold: float = 0.5,
        return_probabilities: bool = True,
    ) -> Dict:
        """Make AMR resistance predictions.

        Args:
            features: Input features (k-mer frequencies).
                Shape: (n_samples, n_features) or (n_features,) for single sample
            threshold: Probability threshold for positive prediction (default: 0.5)
            return_probabilities: Whether to include probability scores

        Returns:
            Dictionary containing:
            - predictions: Dict mapping drug class to resistance status
            - probabilities: Dict mapping drug class to probability (if requested)
            - resistant_count: Number of drug classes with predicted resistance
            - susceptible_count: Number of drug classes predicted susceptible
            For multi-sample input, a list of such dictionaries is returned.

        Raises:
            ValueError: If an empty feature array is supplied.
        """
        if isinstance(features, list):
            features = np.array(features)

        # Promote a single sample to a batch of one.
        if features.ndim == 1:
            features = features.reshape(1, -1)

        if features.size == 0:
            # Fail fast with a clear message instead of an opaque matmul error.
            raise ValueError("predict() received an empty feature array")

        if self.scaler is not None:
            features = self.scaler.transform(features)

        # as_tensor makes the float64 -> float32 conversion (sklearn scalers
        # emit float64) explicit; FloatTensor is a legacy constructor.
        X = torch.as_tensor(features, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits = self.model(X)
            probabilities = torch.sigmoid(logits).cpu().numpy()

        results = []
        for i in range(len(probabilities)):
            probs = probabilities[i]

            # NOTE: per-class optimal thresholds apply whenever the caller
            # uses the default threshold of 0.5 -- including passing 0.5
            # explicitly. Any other value overrides them uniformly. This is
            # existing behavior, kept for backward compatibility.
            if threshold == 0.5 and self.optimal_thresholds:
                preds = np.array([
                    int(probs[j] > self.optimal_thresholds.get(drug, {}).get("threshold", 0.5))
                    for j, drug in enumerate(self.drug_classes)
                ])
            else:
                preds = (probs > threshold).astype(int)

            result = {
                "predictions": {
                    drug: bool(preds[j]) for j, drug in enumerate(self.drug_classes)
                },
                "resistant_count": int(preds.sum()),
                "susceptible_count": int(len(self.drug_classes) - preds.sum()),
            }
            if return_probabilities:
                result["probabilities"] = {
                    drug: float(probs[j]) for j, drug in enumerate(self.drug_classes)
                }
            results.append(result)

        # Return single result if single input
        return results[0] if len(results) == 1 else results

    def predict_batch(
        self,
        features_list: List[np.ndarray],
        threshold: float = 0.5,
        batch_size: int = 32,
    ) -> List[Dict]:
        """Make predictions on a batch of samples efficiently.

        Args:
            features_list: List of feature arrays
            threshold: Probability threshold for positive prediction
            batch_size: Processing batch size

        Returns:
            List of prediction dictionaries
        """
        all_results = []
        for i in range(0, len(features_list), batch_size):
            batch = np.array(features_list[i:i + batch_size])
            batch_results = self.predict(batch, threshold=threshold)
            # predict() returns a bare dict for a size-1 batch.
            if isinstance(batch_results, dict):
                all_results.append(batch_results)
            else:
                all_results.extend(batch_results)
        return all_results

    def get_resistance_summary(self, predictions: Dict) -> Dict:
        """Generate a human-readable summary of predictions.

        Args:
            predictions: Output from predict()

        Returns:
            Summary dictionary with formatted results
        """
        resistant_drugs = [
            drug for drug, status in predictions["predictions"].items() if status
        ]
        susceptible_drugs = [
            drug for drug, status in predictions["predictions"].items() if not status
        ]

        # Risk level assessment by number of resistant classes.
        if predictions["resistant_count"] >= 5:
            risk_level = "HIGH"
            risk_description = "Multi-drug resistant (MDR) - Requires specialist consultation"
        elif predictions["resistant_count"] >= 3:
            risk_level = "MODERATE"
            risk_description = "Multiple resistance detected - Consider alternative treatments"
        elif predictions["resistant_count"] >= 1:
            risk_level = "LOW"
            risk_description = "Limited resistance - Standard alternatives available"
        else:
            risk_level = "MINIMAL"
            risk_description = "No predicted resistance - Standard treatment likely effective"

        summary = {
            "risk_level": risk_level,
            "risk_description": risk_description,
            "resistant_drugs": resistant_drugs,
            "susceptible_drugs": susceptible_drugs,
            "total_tested": len(self.drug_classes),
        }

        # Add probability-based confidence if available.
        if "probabilities" in predictions:
            probs = predictions["probabilities"]
            high_confidence = [
                drug for drug, prob in probs.items() if prob > 0.8 or prob < 0.2
            ]
            summary["high_confidence_predictions"] = len(high_confidence)
            # Cast to builtin float so the summary stays JSON-serializable
            # (np.mean returns np.float64).
            summary["average_confidence"] = float(np.mean([
                max(p, 1 - p) for p in probs.values()
            ]))

        return summary

    @property
    def model_info(self) -> Dict:
        """Get information about the loaded model."""
        info = {
            "model_path": str(self.model_path),
            "device": str(self.device),
            "drug_classes": self.drug_classes,
            "n_classes": len(self.drug_classes),
            "has_scaler": self.scaler is not None,
            "model_version": MODEL_VERSION,
            "has_optimal_thresholds": self.optimal_thresholds is not None,
        }
        if self.performance_metrics:
            info["performance"] = self.performance_metrics
        return info


# =============================================================================
# Sklearn Model Predictor (for ensemble/traditional ML models)
# =============================================================================

class SklearnAMRPredictor:
    """Predictor for sklearn-based AMR models."""

    def __init__(
        self,
        model_path: str = "models/optimized/optimized_ensemble_stacking.joblib",
    ):
        """Initialize sklearn predictor.

        Args:
            model_path: Path to joblib model file

        Raises:
            ImportError: If joblib is not installed.
            FileNotFoundError: If the model file does not exist.
        """
        self.model_path = Path(model_path)
        self._load_model()

    def _load_model(self):
        """Load the sklearn model and components."""
        if joblib is None:
            raise ImportError(
                "joblib is required for SklearnAMRPredictor; "
                "install it with 'pip install joblib'"
            )
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        data = joblib.load(self.model_path)
        self.model = data["model"]
        self.scaler = data.get("scaler")
        self.feature_selector = data.get("feature_selector")
        self.drug_classes = data.get("classes", DeepAMRPredictor.DEFAULT_DRUG_CLASSES)
        logger.info(f"Sklearn model loaded from {self.model_path}")

    def predict(
        self,
        features: np.ndarray,
        return_probabilities: bool = True,
    ) -> Dict:
        """Make predictions with sklearn model.

        Args:
            features: (n_samples, n_features) or (n_features,) feature array.
            return_probabilities: Whether to attempt probability extraction.

        Returns:
            A prediction dict (single sample) or list of dicts (batch),
            mirroring DeepAMRPredictor.predict().
        """
        if features.ndim == 1:
            features = features.reshape(1, -1)

        # Preprocess: optional scaling, then optional feature selection
        # (either a boolean/index mask or a fitted transformer).
        if self.scaler is not None:
            features = self.scaler.transform(features)
        if self.feature_selector is not None:
            if isinstance(self.feature_selector, np.ndarray):
                features = features[:, self.feature_selector]
            else:
                features = self.feature_selector.transform(features)

        predictions = self.model.predict(features)

        results = []
        for i in range(len(predictions)):
            preds = predictions[i]
            result = {
                "predictions": {
                    drug: bool(preds[j]) for j, drug in enumerate(self.drug_classes)
                },
                "resistant_count": int(preds.sum()),
                "susceptible_count": int(len(self.drug_classes) - preds.sum()),
            }
            if return_probabilities and hasattr(self.model, "predict_proba"):
                # Best-effort: multilabel estimators expose predict_proba with
                # varying output shapes (some return a per-class list), so any
                # failure here is deliberately swallowed rather than crashing
                # the prediction itself.
                try:
                    probs = self.model.predict_proba(features[i:i+1])[0]
                    result["probabilities"] = {
                        drug: float(probs[j]) for j, drug in enumerate(self.drug_classes)
                    }
                except Exception:
                    pass
            results.append(result)

        return results[0] if len(results) == 1 else results

    @property
    def model_info(self) -> Dict:
        """Get information about the loaded sklearn model."""
        return {
            "model_path": str(self.model_path),
            "device": "cpu",
            "drug_classes": self.drug_classes,
            "n_classes": len(self.drug_classes),
            "has_scaler": self.scaler is not None,
        }


# =============================================================================
# Unified Predictor Factory
# =============================================================================

def get_predictor(
    model_type: str = "deep_learning",
    model_path: Optional[str] = None,
) -> Union[DeepAMRPredictor, SklearnAMRPredictor]:
    """Factory function to get the appropriate predictor.

    Args:
        model_type: Type of model ('deep_learning', 'sklearn', 'ensemble')
        model_path: Optional custom model path

    Returns:
        Configured predictor instance

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type == "deep_learning":
        path = model_path or "models/advanced_deepamr_system.pt"
        return DeepAMRPredictor(path)
    elif model_type in ["sklearn", "ensemble"]:
        path = model_path or "models/optimized/optimized_ensemble_stacking.joblib"
        return SklearnAMRPredictor(path)
    else:
        raise ValueError(f"Unknown model type: {model_type}")


# =============================================================================
# CLI for testing
# =============================================================================

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DeepAMR Inference")
    parser.add_argument("--model", default="deep_learning", choices=["deep_learning", "sklearn"])
    parser.add_argument("--test", action="store_true", help="Run test prediction")
    args = parser.parse_args()

    predictor = get_predictor(args.model)
    print(f"Model info: {predictor.model_info}")

    if args.test:
        # Load test data for demo (numpy is already imported at module level).
        X_test = np.load("data/processed/ncbi/ncbi_amr_X_test.npy")

        # Predict on first sample
        result = predictor.predict(X_test[0])
        print("\nTest prediction:")
        print(f"Predictions: {result['predictions']}")
        print(f"Resistant: {result['resistant_count']}, Susceptible: {result['susceptible_count']}")
        if "probabilities" in result:
            print("\nProbabilities:")
            for drug, prob in result["probabilities"].items():
                status = "R" if prob > 0.5 else "S"
                print(f"  {drug}: {prob:.3f} ({status})")