#!/usr/bin/env python3 """Advanced Deep Learning System for Multi-Source AMR Prediction. This system integrates data from NCBI, CARD, PATRIC, and ResFinder to build a robust, high-accuracy deep learning model for AMR prediction. Architecture features: 1. Multi-Head Attention for k-mer importance weighting 2. Deep Residual Blocks for feature extraction 3. Focal Loss to handle extreme class imbalance 4. Multi-dataset alignment and normalization """ import json import logging import os from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from datetime import datetime import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler from sklearn.preprocessing import StandardScaler from sklearn.metrics import ( f1_score, roc_auc_score, hamming_loss, precision_recall_curve, ) # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) DEVICE = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) # ============================================================================= # 1. 
# =============================================================================
# 1. Data Integration Engine
# =============================================================================


class MultiSourceDataLoader:
    """Aligns and merges datasets from different sources into one label space.

    Every ``<source>`` directory under ``data_root`` is scanned for
    ``<prefix>_metadata.json`` files. For each one, the matching
    ``<prefix>_X_<split>.npy`` / ``<prefix>_y_<split>.npy`` arrays are loaded,
    labels are remapped onto ``unified_drug_classes``, and all sources are
    stacked per split.

    NOTE(review): a class that a source does not cover (or a sample whose only
    label lies outside the unified list) ends up as 0, i.e. it is treated as a
    negative rather than as a missing label.
    """

    def __init__(self, data_root: str = "data/processed"):
        self.data_root = Path(data_root)
        self.sources = ["ncbi", "card", "patric", "resfinder"]
        self.unified_drug_classes = [
            "aminoglycoside",
            "beta-lactam",
            "fosfomycin",
            "glycopeptide",
            "macrolide",
            "phenicol",
            "quinolone",
            "rifampicin",
            "sulfonamide",
            "tetracycline",
            "trimethoprim",
        ]
        # Fitted on the merged training split in load_and_align() and reused
        # for val/test so no statistics come from held-out data.
        self.scaler = StandardScaler()

    def _build_class_map(self, source_classes: List[str]) -> Dict[str, int]:
        """Map each source class name present in the unified space to its column."""
        unified_index: Dict[str, int] = {}
        for idx, cls in enumerate(self.unified_drug_classes):
            unified_index.setdefault(cls, idx)  # first occurrence, like list.index
        return {
            cls: unified_index[cls] for cls in source_classes if cls in unified_index
        }

    @staticmethod
    def _align_labels(
        y_orig: np.ndarray,
        source_classes: List[str],
        class_map: Dict[str, int],
        n_unified: int,
    ) -> np.ndarray:
        """Project one source's labels onto the unified multi-label matrix.

        Args:
            y_orig: 1-D array of class indices into ``source_classes``
                (multiclass), or a 2-D indicator matrix whose columns follow
                ``source_classes`` (multi-label).
            source_classes: Class names used by the source.
            class_map: Source class name -> unified column index.
            n_unified: Number of unified classes (output width).

        Returns:
            Float array of shape ``(len(y_orig), n_unified)``.

        Raises:
            ValueError: If a 2-D ``y_orig`` has fewer columns than
                ``source_classes``.
        """
        y_orig = np.asarray(y_orig)
        y_aligned = np.zeros((len(y_orig), n_unified))

        if y_orig.ndim == 1:
            # Multiclass (one label per sample) -> one-hot in the unified space.
            # lookup[i] is the unified column of source class i, or -1.
            lookup = np.array(
                [class_map.get(cls, -1) for cls in source_classes], dtype=np.int64
            )
            labels = y_orig.astype(np.int64)
            # Only 0 <= label < len(source_classes) is a real class. Checking
            # just the upper bound let negative sentinels (e.g. -1) wrap
            # around to the last source class.
            in_range = (labels >= 0) & (labels < len(source_classes))
            rows = np.flatnonzero(in_range)
            cols = lookup[labels[in_range]]
            mapped = cols >= 0
            y_aligned[rows[mapped], cols[mapped]] = 1
            return y_aligned

        # Already multi-label: copy the columns that have a unified counterpart.
        if y_orig.shape[1] < len(source_classes):
            raise ValueError(
                f"Label matrix has {y_orig.shape[1]} columns but metadata lists "
                f"{len(source_classes)} classes"
            )
        for old_idx, cls in enumerate(source_classes):
            if cls in class_map:
                y_aligned[:, class_map[cls]] = y_orig[:, old_idx]
        return y_aligned

    def load_and_align(self) -> Dict:
        """Loads all available datasets and aligns them to a unified label space.

        Returns:
            Dict with ``X_train``/``y_train``/``X_val``/``y_val``/``X_test``/
            ``y_test``. The X arrays are standardized with a scaler fitted on
            ``X_train`` only.

        Raises:
            RuntimeError: If a split receives no data from any source.
            ValueError: If X and y row counts differ for a file pair, or if
                sources disagree on the feature width.
        """
        logger.info("Aligning multi-source datasets...")
        splits = ["train", "val", "test"]
        merged_data = {split: {"X": [], "y": []} for split in splits}
        n_unified = len(self.unified_drug_classes)

        for source in self.sources:
            source_dir = self.data_root / source
            if not source_dir.exists():
                logger.warning(f"Source directory not found: {source_dir}")
                continue

            # sorted(): glob order is filesystem-dependent; sorting makes the
            # concatenation order reproducible.
            for meta_path in sorted(source_dir.glob("*_metadata.json")):
                prefix = meta_path.name.replace("_metadata.json", "")
                with open(meta_path, encoding="utf-8") as f:
                    meta = json.load(f)

                # Skip if not AMR drug class target
                if meta.get("target") not in [
                    "amr_drug_class",
                    "drug_class",
                    "multilabel",
                ]:
                    # Some datasets lack 'target' but still list their classes.
                    if "drug_classes" not in meta and "class_names" not in meta:
                        continue

                source_classes = meta.get("class_names", []) or meta.get(
                    "drug_classes", []
                )
                class_map = self._build_class_map(source_classes)
                if not class_map:
                    continue

                logger.info(
                    f"Merging {source} ({prefix}) with {len(class_map)} matching classes"
                )

                for split in splits:
                    x_path = source_dir / f"{prefix}_X_{split}.npy"
                    y_path = source_dir / f"{prefix}_y_{split}.npy"
                    if not (x_path.exists() and y_path.exists()):
                        continue

                    X = np.load(x_path)
                    y_orig = np.load(y_path)
                    if len(X) != len(y_orig):
                        raise ValueError(
                            f"{x_path} has {len(X)} rows but {y_path} has "
                            f"{len(y_orig)}"
                        )

                    merged_data[split]["X"].append(X)
                    merged_data[split]["y"].append(
                        self._align_labels(y_orig, source_classes, class_map, n_unified)
                    )

        # Concatenate all sources
        final_data = {}
        for split in splits:
            x_parts = merged_data[split]["X"]
            if not x_parts:
                raise RuntimeError(f"No data found for {split} split")
            widths = sorted({part.shape[-1] for part in x_parts})
            if len(widths) > 1:
                raise ValueError(
                    f"Feature width differs between sources in {split} split: {widths}"
                )
            final_data[f"X_{split}"] = np.vstack(x_parts)
            final_data[f"y_{split}"] = np.vstack(merged_data[split]["y"])

        # Scale features (fit on train only).
        final_data["X_train"] = self.scaler.fit_transform(final_data["X_train"])
        final_data["X_val"] = self.scaler.transform(final_data["X_val"])
        final_data["X_test"] = self.scaler.transform(final_data["X_test"])

        logger.info(
            f"Unified dataset created: {len(final_data['X_train'])} training samples"
        )
        return final_data
# =============================================================================
# 2. Advanced Neural Architecture
# =============================================================================


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention applied to a single vector.

    ``forward`` views each ``(batch, d_model)`` input as a sequence of length 1.
    With one key per query the softmax is always exactly 1, so the output
    equals ``out_linear(v_linear(x))``: the layer is a learned projection and
    does not re-weight individual features/k-mers. The structure is kept so
    existing checkpoints (parameter names and shapes) still load.

    Args:
        d_model: Feature dimension of input and output.
        n_heads: Number of heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Without this check a bad pair fails later inside .view() with an
        # opaque shape error.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_head)."""
        return t.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        """Map a (batch, d_model) tensor to a (batch, d_model) tensor."""
        x = x.unsqueeze(1)  # (batch, 1, d_model): a length-1 sequence
        batch_size = x.size(0)
        q = self._split_heads(self.q_linear(x), batch_size)
        k = self._split_heads(self.k_linear(x), batch_size)
        v = self._split_heads(self.v_linear(x), batch_size)

        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        context = (
            torch.matmul(attn, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.n_heads * self.d_head)
        )
        return self.out_linear(context).squeeze(1)


class ResidualBlock(nn.Module):
    """Shape-preserving residual MLP block: ``GELU(x + net(x))`` on (batch, dim)."""

    def __init__(self, dim, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        )
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(x + self.net(x))


class AdvancedDeepAMR(nn.Module):
    """Advanced Deep Learning Model for AMR Prediction.

    Pipeline: embedding (Linear + BatchNorm + GELU) -> residual attention ->
    ``n_blocks`` residual blocks -> two-layer classifier. Returns raw logits of
    shape (batch, output_dim); apply ``torch.sigmoid`` for per-class
    probabilities.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output classes (logits).
        hidden_dim: Width of the hidden representation; must be divisible
            by ``n_heads``.
        n_blocks: Number of residual blocks.
        n_heads: Attention heads (default 8, the previously hard-coded value).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512, n_blocks=4, n_heads=8):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.GELU()
        )
        self.attention = MultiHeadAttention(hidden_dim, n_heads=n_heads)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(hidden_dim) for _ in range(n_blocks)]
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x):
        x = self.embedding(x)
        x = x + self.attention(x)
        for block in self.res_blocks:
            x = block(x)
        return self.classifier(x)


# =============================================================================
# 3. Training Logic (Focal Loss + Early Stopping)
# =============================================================================


class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements.

    ``loss = alpha * (1 - p_t) ** gamma * BCE`` where ``p_t = exp(-BCE)`` is the
    predicted probability of the true label. ``gamma=0`` and ``alpha=1``
    reduce it to mean BCE-with-logits. Targets must be float 0/1 tensors of the
    same shape as ``inputs``.
    """

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = float(gamma)

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()


def _build_train_loader(dataset, batch_size=64):
    """Shuffled training DataLoader that never yields a one-sample batch.

    BatchNorm1d raises in training mode on a batch of one sample, which happens
    whenever ``len(dataset) % batch_size == 1``. In that case only the trailing
    lone sample is dropped (a different one each epoch, since data is shuffled).
    """
    n = len(dataset)
    drop_last = n > batch_size and n % batch_size == 1
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)


def _predict_proba(model, loader):
    """Run ``model`` in eval mode over ``loader``.

    Returns:
        Tuple ``(probabilities, targets)`` of NumPy arrays stacked over all
        batches; probabilities are ``sigmoid(logits)``.
    """
    device = next(model.parameters()).device
    model.eval()
    probs, targets = [], []
    with torch.no_grad():
        for batch_X, batch_y in loader:
            logits = model(batch_X.to(device))
            probs.append(torch.sigmoid(logits).cpu().numpy())
            targets.append(batch_y.numpy())
    return np.vstack(probs), np.vstack(targets)


def train_advanced_system():
    """Train, select, and evaluate the unified multi-source AMR model.

    Loads and aligns all sources via ``MultiSourceDataLoader`` and trains
    ``AdvancedDeepAMR`` with ``FocalLoss``. The checkpoint with the best
    validation micro-F1 is kept in ``models/advanced_deepamr_system.pt``;
    training stops early after ``patience`` epochs without improvement. The
    best checkpoint is then reloaded, and test metrics are written to
    ``models/advanced_system_results.json``.

    Raises:
        ValueError: If fewer than 2 training samples are available
            (BatchNorm cannot train on a single sample).
    """
    logger.info("Starting Advanced DL System Training...")

    batch_size = 64
    hidden_dim = 512
    n_blocks = 4
    epochs = 100
    patience = 15
    threshold = 0.5
    model_dir = Path("models")
    checkpoint_path = model_dir / "advanced_deepamr_system.pt"
    results_path = model_dir / "advanced_system_results.json"

    # 1. Load and Merge Data
    loader = MultiSourceDataLoader()
    data = loader.load_and_align()
    X_train = torch.FloatTensor(data["X_train"])
    y_train = torch.FloatTensor(data["y_train"])
    X_val = torch.FloatTensor(data["X_val"])
    y_val = torch.FloatTensor(data["y_val"])
    X_test = torch.FloatTensor(data["X_test"])
    y_test = torch.FloatTensor(data["y_test"])
    if len(X_train) < 2:
        raise ValueError(f"Need at least 2 training samples, got {len(X_train)}")

    # Class imbalance is handled by FocalLoss; batches are plain shuffled
    # batches (no weighted sampling).
    train_loader = _build_train_loader(TensorDataset(X_train, y_train), batch_size)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)
    test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size)

    # 2. Build Model
    model = AdvancedDeepAMR(
        input_dim=X_train.shape[1],
        output_dim=y_train.shape[1],
        hidden_dim=hidden_dim,
        n_blocks=n_blocks,
    ).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # T_max spans the whole run: cosine annealing is periodic, so T_max=50
    # with 100 epochs pushed the LR back up to its peak in the second half.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = FocalLoss(gamma=2.5)  # High gamma for extreme imbalance

    # -1 (not 0) so the first epoch always writes a checkpoint. Otherwise a run
    # whose val F1 never rises above 0 saves nothing and the final reload hits
    # a missing or stale file.
    best_f1 = -1.0
    counter = 0
    model_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Model Training on {DEVICE}...")
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()

        # Validation: micro-F1 at a fixed threshold drives model selection.
        val_prob, val_true = _predict_proba(model, val_loader)
        val_pred = (val_prob > threshold).astype(int)
        val_f1 = f1_score(val_true, val_pred, average="micro", zero_division=0)

        if (epoch + 1) % 5 == 0:
            logger.info(
                f"Epoch {epoch + 1:03d} | Train Loss: {total_loss / len(train_loader):.4f} | Val F1: {val_f1:.4f}"
            )

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "scaler": loader.scaler,
                    "classes": loader.unified_drug_classes,
                    "input_dim": X_train.shape[1],
                    "output_dim": y_train.shape[1],
                    "hidden_dim": hidden_dim,
                    "n_blocks": n_blocks,
                    "best_val_f1": float(best_f1),
                    "epoch": epoch + 1,
                },
                checkpoint_path,
            )
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                logger.info("Early stopping triggered.")
                break

    # Final Test Evaluation on the best checkpoint.
    # weights_only=False because the checkpoint also pickles the fitted
    # StandardScaler; only load checkpoints this script wrote itself.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    y_prob, y_true = _predict_proba(model, test_loader)
    y_pred = (y_prob > threshold).astype(int)

    try:
        micro_auc = float(roc_auc_score(y_true, y_prob, average="micro"))
    except ValueError:
        # ROC AUC is undefined when the test labels contain only one value.
        logger.warning("Micro AUC is undefined for this test set")
        micro_auc = None

    metrics = {
        "micro_f1": float(f1_score(y_true, y_pred, average="micro", zero_division=0)),
        "macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "hamming_loss": float(hamming_loss(y_true, y_pred)),
        "micro_auc": micro_auc,
    }
    auc_text = "n/a" if micro_auc is None else f"{micro_auc:.4f}"

    logger.info("=" * 60)
    logger.info("FINAL SYSTEM PERFORMANCE (Unified Multi-Source)")
    logger.info(f"Micro F1: {metrics['micro_f1']:.4f}")
    logger.info(f"Macro F1: {metrics['macro_f1']:.4f}")
    logger.info(f"Micro AUC: {auc_text}")
    logger.info(f"Hamming Loss: {metrics['hamming_loss']:.4f}")
    logger.info("=" * 60)

    # Save detailed results
    with open(results_path, "w") as f:
        json.dump(metrics, f, indent=2)


if __name__ == "__main__":
    train_advanced_system()