Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| DeepAMR: Deep Learning Models for Antimicrobial Resistance Prediction | |
| This script trains deep learning models for: | |
| 1. Organism Classification (multiclass) | |
| 2. AMR Drug Resistance Prediction (multilabel) | |
| Designed for high-impact deployment in Bangladesh healthcare systems. | |
| Usage: | |
| python src/ml/deep_learning_trainer.py --task organism | |
| python src/ml/deep_learning_trainer.py --task amr | |
| python src/ml/deep_learning_trainer.py --task both | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| classification_report, | |
| f1_score, | |
| precision_score, | |
| recall_score, | |
| roc_auc_score, | |
| confusion_matrix, | |
| ) | |
| from sklearn.preprocessing import StandardScaler | |
# Configure logging
# Module-wide logger: INFO level with timestamped, leveled lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Set device
# Accelerator preference order: CUDA GPU, then Apple-silicon MPS, else CPU.
# NOTE: evaluated once at import time; all trainers default to this device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else
                      'mps' if torch.backends.mps.is_available() else 'cpu')
logger.info(f"Using device: {DEVICE}")
| # ============================================================================= | |
| # Neural Network Architectures | |
| # ============================================================================= | |
class OrganismClassifier(nn.Module):
    """Deep neural network for organism classification from k-mer features.

    A stack of Linear -> BatchNorm1d -> ReLU -> Dropout blocks followed by a
    final linear projection to raw class logits (no softmax; the loss is
    expected to apply it).
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: Tuple[int, ...] = (256, 128, 64),
        num_classes: int = 8,
        dropout_rate: float = 0.3,
    ):
        """Build the MLP.

        Args:
            input_size: Number of input features per sample.
            hidden_sizes: Width of each hidden block, in order. Immutable
                tuple default (fixes the original mutable-list default);
                callers may still pass a list.
            num_classes: Number of organism classes (logit dimension).
            dropout_rate: Dropout probability after every hidden block.
        """
        super().__init__()
        layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            prev_size = hidden_size
        # Final projection to class logits.
        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """He (kaiming) init for Linear layers; identity init for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes)."""
        return self.network(x)
class AMRPredictor(nn.Module):
    """Deep neural network for multi-label AMR prediction.

    A shared Linear/BatchNorm/LeakyReLU trunk feeds one small head per drug
    class; each head emits a single raw logit, concatenated to shape
    (batch, num_classes). Apply sigmoid externally for probabilities.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: Tuple[int, ...] = (512, 256, 128),
        num_classes: int = 11,
        dropout_rate: float = 0.4,
    ):
        """Build the shared trunk and per-drug heads.

        Args:
            input_size: Number of input features per sample.
            hidden_sizes: All but the last entry size the shared trunk; the
                last entry sizes each head's hidden layer. Immutable tuple
                default (fixes the original mutable-list default).
            num_classes: Number of drug-resistance labels (one head each).
            dropout_rate: Dropout in the trunk; heads use half this rate.
        """
        super().__init__()
        # Shared feature extractor
        shared_layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes[:-1]:
            shared_layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.LeakyReLU(0.1),
                nn.Dropout(dropout_rate),
            ])
            prev_size = hidden_size
        self.shared = nn.Sequential(*shared_layers)
        # Drug-class specific heads for better performance
        self.drug_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(prev_size, hidden_sizes[-1]),
                nn.BatchNorm1d(hidden_sizes[-1]),
                nn.LeakyReLU(0.1),
                nn.Dropout(dropout_rate * 0.5),
                nn.Linear(hidden_sizes[-1], 1),
            )
            for _ in range(num_classes)
        ])
        self._init_weights()

    def _init_weights(self):
        """He (kaiming) init tuned for leaky_relu; identity init for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-drug logits of shape (batch, num_classes)."""
        shared_features = self.shared(x)
        outputs = [head(shared_features) for head in self.drug_heads]
        return torch.cat(outputs, dim=1)
class ResidualBlock(nn.Module):
    """Identity-skip block: y = ReLU(x + F(x)), where F is a two-layer MLP
    (Linear -> BatchNorm -> ReLU -> Dropout -> Linear -> BatchNorm)."""

    def __init__(self, size: int, dropout_rate: float = 0.3):
        super().__init__()
        # Inner transform keeps the feature width constant so the skip
        # connection can add element-wise.
        inner = [
            nn.Linear(size, size),
            nn.BatchNorm1d(size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(size, size),
            nn.BatchNorm1d(size),
        ]
        self.block = nn.Sequential(*inner)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual transform and the post-sum activation."""
        transformed = self.block(x)
        return self.relu(x + transformed)
class DeepAMRNet(nn.Module):
    """Advanced deep network with residual connections for AMR prediction."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int = 256,
        num_residual_blocks: int = 3,
        num_classes: int = 11,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        # Project raw features into the constant-width residual trunk.
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )
        # Stack of identity-skip blocks at width `hidden_size`.
        trunk = [
            ResidualBlock(hidden_size, dropout_rate)
            for _ in range(num_residual_blocks)
        ]
        self.residual_blocks = nn.Sequential(*trunk)
        # Taper down before emitting one logit per drug class.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.BatchNorm1d(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(hidden_size // 2, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        # He initialization for every Linear layer; biases start at zero.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return multilabel logits of shape (batch, num_classes)."""
        hidden = self.input_layer(x)
        hidden = self.residual_blocks(hidden)
        return self.output_layer(hidden)
| # ============================================================================= | |
| # Training Utilities | |
| # ============================================================================= | |
class EarlyStopping:
    """Stop training once a monitored score stops improving.

    Tracks the best score seen so far. After `patience` consecutive calls
    without an improvement of more than `min_delta`, the stopper latches
    (every subsequent call keeps returning True).
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'min'):
        # mode='min' means lower is better (e.g. loss); 'max' the opposite.
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """Record `score` and return True when training should stop."""
        if self.best_score is None or self._is_improvement(score):
            # First observation, or a genuine improvement: reset the clock.
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def _is_improvement(self, score: float) -> bool:
        """True when `score` beats the best by more than `min_delta`."""
        if self.mode == 'min':
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta
class FocalLoss(nn.Module):
    """Focal loss for handling class imbalance in multilabel classification.

    Down-weights well-classified elements by (1 - p_t)**gamma so the easy
    examples contribute less; `alpha` is a constant scaling factor. Returns
    the mean over all elements.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the mean focal loss of logits `inputs` against `targets`."""
        per_element = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # p_t is the probability assigned to the true label; exp(-bce)
        # recovers it element-wise from the unreduced BCE.
        p_t = torch.exp(-per_element)
        modulator = (1 - p_t) ** self.gamma
        return (self.alpha * modulator * per_element).mean()
def compute_class_weights(y: np.ndarray, task: str = 'multiclass') -> torch.Tensor:
    """Compute class weights to handle imbalanced data.

    For 'multiclass', returns balanced inverse-frequency weights
    (n_samples / (n_classes * count_per_class)) over integer labels.
    For any other task value, treats `y` as a multilabel indicator matrix
    and returns per-label negative/positive count ratios, clipped to
    [1.0, 10.0].
    """
    if task == 'multiclass':
        counts = np.bincount(y)
        balanced = len(y) / (len(counts) * counts)
        return torch.FloatTensor(balanced)
    # multilabel
    positives = y.sum(axis=0)
    negatives = len(y) - positives
    ratio = negatives / (positives + 1e-6)
    # Cap the weights so extremely rare labels cannot dominate the loss.
    return torch.FloatTensor(np.clip(ratio, 1.0, 10.0))
| # ============================================================================= | |
| # Trainer Classes | |
| # ============================================================================= | |
class BaseTrainer:
    """Base trainer with shared optimizer/scheduler setup and checkpointing.

    Subclasses set `self.criterion` and implement the `train_epoch` /
    `validate` / `train` loops.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        learning_rate: float = 0.001,
        weight_decay: float = 0.01,
        device: Optional[torch.device] = None,
    ):
        """Move the model to `device` and build AdamW + plateau scheduler.

        Args:
            model: Network to train (moved to `device` in place).
            train_loader: Mini-batches for optimization.
            val_loader: Mini-batches for validation.
            test_loader: Mini-batches for final evaluation.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW decoupled weight decay.
            device: Target device; when None, falls back to the module-level
                DEVICE. Resolved lazily at call time rather than at class
                definition time.
        """
        if device is None:
            device = DEVICE
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        # Halve the LR after 5 epochs without validation-loss improvement.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
        )
        # Per-epoch bookkeeping appended to by the subclass training loops.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_metrics': [],
            'val_metrics': [],
            'learning_rates': [],
        }

    def save_checkpoint(self, path: str, epoch: int, metrics: Dict):
        """Save model/optimizer/scheduler state plus metrics and history."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'history': self.history,
        }
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore a checkpoint written by `save_checkpoint`.

        Returns:
            (epoch, metrics) stored in the checkpoint.
        """
        # weights_only=False: required on torch>=2.6 (new default True) since
        # the checkpoint holds non-tensor objects (metrics/history values).
        # Only load checkpoints this pipeline produced — full unpickling
        # executes arbitrary code from untrusted files.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.history = checkpoint['history']
        return checkpoint['epoch'], checkpoint['metrics']
class OrganismTrainer(BaseTrainer):
    """Trainer for organism classification.

    Optimizes class-weighted cross-entropy over raw logits; model selection
    and early stopping are driven by validation macro-F1.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        class_weights: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Set up the trainer.

        Args:
            model: Network emitting (batch, num_classes) logits.
            train_loader: Mini-batches of (features, integer class labels).
            val_loader: Validation-split loader.
            test_loader: Held-out test-split loader.
            class_weights: Optional per-class CrossEntropyLoss weights.
            **kwargs: Forwarded to BaseTrainer (learning_rate, weight_decay,
                device).
        """
        super().__init__(model, train_loader, val_loader, test_loader, **kwargs)
        if class_weights is not None:
            # Loss weights must live on the same device as the logits.
            class_weights = class_weights.to(self.device)
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

    def train_epoch(self) -> Tuple[float, Dict]:
        """Train for one epoch.

        Returns:
            (average batch loss, dict with accuracy / macro-F1 / weighted-F1
            over all training samples seen this epoch).
        """
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        for batch_x, batch_y in self.train_loader:
            batch_x = batch_x.to(self.device)
            batch_y = batch_y.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(batch_x)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            # Gradient clipping guards against exploding gradients.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            total_loss += loss.item()
            # Hard predictions = argmax over the logit dimension.
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_y.cpu().numpy())
        avg_loss = total_loss / len(self.train_loader)
        metrics = {
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_weighted': f1_score(all_labels, all_preds, average='weighted'),
        }
        return avg_loss, metrics

    def validate(self, loader: DataLoader) -> Tuple[float, Dict, List, List, np.ndarray]:
        """Evaluate the model on `loader` without gradient tracking.

        Returns:
            A 5-tuple: (average loss, metrics dict, predicted labels,
            true labels, softmax probability matrix).
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()
                probs = torch.softmax(outputs, dim=1).cpu().numpy()
                preds = outputs.argmax(dim=1).cpu().numpy()
                all_probs.extend(probs)
                all_preds.extend(preds)
                all_labels.extend(batch_y.cpu().numpy())
        avg_loss = total_loss / len(loader)
        all_probs = np.array(all_probs)
        metrics = {
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_weighted': f1_score(all_labels, all_preds, average='weighted'),
            'precision_macro': precision_score(all_labels, all_preds, average='macro'),
            'recall_macro': recall_score(all_labels, all_preds, average='macro'),
        }
        # ROC-AUC for multiclass (one-vs-rest); roc_auc_score raises
        # ValueError when a class is absent from this split, so fall back to 0.
        try:
            metrics['roc_auc'] = roc_auc_score(
                all_labels, all_probs, multi_class='ovr', average='macro'
            )
        except ValueError:
            metrics['roc_auc'] = 0.0
        return avg_loss, metrics, all_preds, all_labels, all_probs

    def train(
        self,
        epochs: int = 100,
        patience: int = 15,
        save_path: str = 'models/organism_classifier.pt',
    ) -> Dict:
        """Full training loop.

        Runs up to `epochs` epochs with early stopping on validation
        macro-F1, checkpointing to `save_path` whenever the score improves.

        NOTE: the final test evaluation uses the in-memory weights from the
        last epoch run, not the reloaded best checkpoint.

        Returns:
            Dict with the training history, test metrics, and test
            predictions / labels / probabilities.
        """
        early_stopping = EarlyStopping(patience=patience, mode='max')
        best_f1 = 0
        logger.info("Starting organism classification training...")
        logger.info(f"Training samples: {len(self.train_loader.dataset)}")
        logger.info(f"Validation samples: {len(self.val_loader.dataset)}")
        for epoch in range(epochs):
            # Train
            train_loss, train_metrics = self.train_epoch()
            # Validate
            val_loss, val_metrics, _, _, _ = self.validate(self.val_loader)
            # Update scheduler (ReduceLROnPlateau watches validation loss)
            self.scheduler.step(val_loss)
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_metrics'].append(train_metrics)
            self.history['val_metrics'].append(val_metrics)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])
            # Logging
            logger.info(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                f"Train Acc: {train_metrics['accuracy']:.4f} | "
                f"Val Acc: {val_metrics['accuracy']:.4f} | "
                f"Val F1: {val_metrics['f1_macro']:.4f}"
            )
            # Save best model
            if val_metrics['f1_macro'] > best_f1:
                best_f1 = val_metrics['f1_macro']
                Path(save_path).parent.mkdir(parents=True, exist_ok=True)
                self.save_checkpoint(save_path, epoch, val_metrics)
                logger.info(f"New best model saved! F1: {best_f1:.4f}")
            # Early stopping
            if early_stopping(val_metrics['f1_macro']):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break
        # Final evaluation on test set
        logger.info("\nEvaluating on test set...")
        test_loss, test_metrics, test_preds, test_labels, test_probs = self.validate(
            self.test_loader
        )
        logger.info(f"\nTest Results:")
        logger.info(f"  Accuracy: {test_metrics['accuracy']:.4f}")
        logger.info(f"  F1 (macro): {test_metrics['f1_macro']:.4f}")
        logger.info(f"  F1 (weighted): {test_metrics['f1_weighted']:.4f}")
        logger.info(f"  ROC-AUC: {test_metrics['roc_auc']:.4f}")
        return {
            'history': self.history,
            'test_metrics': test_metrics,
            'test_predictions': test_preds,
            'test_labels': test_labels,
            'test_probabilities': test_probs,
        }
class AMRTrainer(BaseTrainer):
    """Trainer for multilabel AMR prediction.

    Each output column is an independent binary resistance label; the loss
    is either focal loss or positively-weighted BCE over raw logits.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        pos_weights: Optional[torch.Tensor] = None,
        use_focal_loss: bool = True,
        **kwargs,
    ):
        """Set up the trainer.

        Args:
            model: Network emitting (batch, num_labels) logits.
            pos_weights: Per-label positive weights for BCEWithLogitsLoss;
                ignored when `use_focal_loss` is True.
            use_focal_loss: Use FocalLoss instead of weighted BCE.
            **kwargs: Forwarded to BaseTrainer.
        """
        super().__init__(model, train_loader, val_loader, test_loader, **kwargs)
        if use_focal_loss:
            self.criterion = FocalLoss(alpha=0.25, gamma=2.0)
        else:
            if pos_weights is not None:
                pos_weights = pos_weights.to(self.device)
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

    def train_epoch(self) -> Tuple[float, Dict]:
        """Train for one epoch.

        Returns:
            (average batch loss, dict with micro / macro / samples F1).
        """
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        for batch_x, batch_y in self.train_loader:
            batch_x = batch_x.to(self.device)
            # BCE-style losses require float targets.
            batch_y = batch_y.float().to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(batch_x)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            total_loss += loss.item()
            # Threshold sigmoid probabilities at 0.5 for hard predictions.
            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_y.cpu().numpy())
        avg_loss = total_loss / len(self.train_loader)
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        metrics = {
            'f1_micro': f1_score(all_labels, all_preds, average='micro'),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_samples': f1_score(all_labels, all_preds, average='samples'),
        }
        return avg_loss, metrics

    def validate(self, loader: DataLoader) -> Tuple[float, Dict, np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate the model on `loader` without gradient tracking.

        Returns:
            A 5-tuple: (average loss, metrics dict, binary prediction
            matrix, label matrix, sigmoid probability matrix).
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.float().to(self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()
                probs = torch.sigmoid(outputs).cpu().numpy()
                preds = (probs > 0.5).astype(int)
                all_probs.extend(probs)
                all_preds.extend(preds)
                all_labels.extend(batch_y.cpu().numpy())
        avg_loss = total_loss / len(loader)
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)
        metrics = {
            'f1_micro': f1_score(all_labels, all_preds, average='micro'),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_samples': f1_score(all_labels, all_preds, average='samples'),
            'precision_micro': precision_score(all_labels, all_preds, average='micro'),
            'recall_micro': recall_score(all_labels, all_preds, average='micro'),
        }
        # Per-class metrics (one F1 per drug label).
        per_class_f1 = f1_score(all_labels, all_preds, average=None)
        metrics['per_class_f1'] = per_class_f1.tolist()
        # ROC-AUC; roc_auc_score raises ValueError when a label column is
        # constant in this split, so fall back to 0.
        try:
            metrics['roc_auc_micro'] = roc_auc_score(all_labels, all_probs, average='micro')
            metrics['roc_auc_macro'] = roc_auc_score(all_labels, all_probs, average='macro')
        except ValueError:
            metrics['roc_auc_micro'] = 0.0
            metrics['roc_auc_macro'] = 0.0
        return avg_loss, metrics, all_preds, all_labels, all_probs

    def train(
        self,
        epochs: int = 100,
        patience: int = 15,
        save_path: str = 'models/amr_predictor.pt',
    ) -> Dict:
        """Full training loop.

        Runs up to `epochs` epochs with early stopping on validation
        macro-F1, checkpointing to `save_path` whenever the score improves.

        NOTE: the final test evaluation uses the in-memory weights from the
        last epoch run, not the reloaded best checkpoint.

        Returns:
            Dict with the training history, test metrics, and test
            predictions / labels / probabilities.
        """
        early_stopping = EarlyStopping(patience=patience, mode='max')
        best_f1 = 0
        logger.info("Starting AMR prediction training...")
        logger.info(f"Training samples: {len(self.train_loader.dataset)}")
        logger.info(f"Validation samples: {len(self.val_loader.dataset)}")
        for epoch in range(epochs):
            # Train
            train_loss, train_metrics = self.train_epoch()
            # Validate
            val_loss, val_metrics, _, _, _ = self.validate(self.val_loader)
            # Update scheduler (ReduceLROnPlateau watches validation loss)
            self.scheduler.step(val_loss)
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_metrics'].append(train_metrics)
            self.history['val_metrics'].append(val_metrics)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])
            # Logging
            logger.info(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                f"Train F1: {train_metrics['f1_macro']:.4f} | "
                f"Val F1: {val_metrics['f1_macro']:.4f} | "
                f"Val AUC: {val_metrics.get('roc_auc_macro', 0):.4f}"
            )
            # Save best model
            if val_metrics['f1_macro'] > best_f1:
                best_f1 = val_metrics['f1_macro']
                Path(save_path).parent.mkdir(parents=True, exist_ok=True)
                self.save_checkpoint(save_path, epoch, val_metrics)
                logger.info(f"New best model saved! F1: {best_f1:.4f}")
            # Early stopping
            if early_stopping(val_metrics['f1_macro']):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break
        # Final evaluation on test set
        logger.info("\nEvaluating on test set...")
        test_loss, test_metrics, test_preds, test_labels, test_probs = self.validate(
            self.test_loader
        )
        logger.info(f"\nTest Results:")
        logger.info(f"  F1 (micro): {test_metrics['f1_micro']:.4f}")
        logger.info(f"  F1 (macro): {test_metrics['f1_macro']:.4f}")
        logger.info(f"  F1 (samples): {test_metrics['f1_samples']:.4f}")
        logger.info(f"  ROC-AUC (macro): {test_metrics['roc_auc_macro']:.4f}")
        return {
            'history': self.history,
            'test_metrics': test_metrics,
            'test_predictions': test_preds,
            'test_labels': test_labels,
            'test_probabilities': test_probs,
        }
| # ============================================================================= | |
| # Data Loading | |
| # ============================================================================= | |
def load_data(task: str = 'organism') -> Tuple:
    """Load the preprocessed train/val/test splits plus metadata for `task`.

    Reads the six `.npy` arrays and the metadata JSON written by the
    preprocessing pipeline into data/processed/ncbi.

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, metadata).
    """
    data_dir = Path('data/processed/ncbi')
    # File prefix depends on which task's artifacts we are loading.
    prefix = 'ncbi_organism' if task == 'organism' else 'ncbi_amr'

    def _split(name: str) -> np.ndarray:
        return np.load(data_dir / f'{prefix}_{name}.npy')

    X_train, X_val, X_test = _split('X_train'), _split('X_val'), _split('X_test')
    y_train, y_val, y_test = _split('y_train'), _split('y_val'), _split('y_test')
    with open(data_dir / f'{prefix}_metadata.json') as f:
        metadata = json.load(f)
    logger.info(f"Loaded {task} data:")
    logger.info(f"  Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    return X_train, X_val, X_test, y_train, y_val, y_test, metadata
def create_dataloaders(
    X_train: np.ndarray,
    X_val: np.ndarray,
    X_test: np.ndarray,
    y_train: np.ndarray,
    y_val: np.ndarray,
    y_test: np.ndarray,
    batch_size: int = 32,
    normalize: bool = True,
) -> Tuple[DataLoader, DataLoader, DataLoader, Optional["StandardScaler"]]:
    """Create PyTorch DataLoaders for the three splits.

    Args:
        X_train/X_val/X_test: Feature matrices.
        y_train/y_val/y_test: Targets. 1-D arrays are treated as integer
            class labels (LongTensor); 2-D arrays as multilabel float
            targets (FloatTensor).
        batch_size: Mini-batch size for every loader.
        normalize: When True, fit a StandardScaler on the training features
            only and apply it to all three splits.

    Returns:
        (train_loader, val_loader, test_loader, fitted scaler or None).
        Only the training loader shuffles.
    """
    scaler = None
    if normalize:
        scaler = StandardScaler()
        # Fit on train only so validation/test stay unseen.
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)

    def _to_dataset(X: np.ndarray, y: np.ndarray) -> TensorDataset:
        # 1-D targets are class indices (CrossEntropyLoss wants int64);
        # 2-D targets are multilabel indicators (BCE-style losses want float32).
        target = torch.LongTensor(y) if y.ndim == 1 else torch.FloatTensor(y)
        return TensorDataset(torch.FloatTensor(X), target)

    train_loader = DataLoader(
        _to_dataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(_to_dataset(X_val, y_val), batch_size=batch_size)
    test_loader = DataLoader(_to_dataset(X_test, y_test), batch_size=batch_size)
    return train_loader, val_loader, test_loader, scaler
| # ============================================================================= | |
| # Main Training Functions | |
| # ============================================================================= | |
def train_organism_classifier(
    epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    hidden_sizes: List[int] = [256, 128, 64],
    dropout_rate: float = 0.3,
    save_dir: str = 'models',
) -> Dict:
    """Train organism classification model.

    End-to-end pipeline: load the preprocessed organism splits, build
    (normalized) dataloaders, compute inverse-frequency class weights,
    train an OrganismClassifier, then persist the checkpoint, feature
    scaler, results JSON, and history JSON under `save_dir`.

    Args:
        epochs: Maximum training epochs.
        batch_size: Mini-batch size.
        learning_rate: AdamW learning rate.
        hidden_sizes: Hidden-layer widths for the classifier.
        dropout_rate: Dropout probability per hidden block.
        save_dir: Directory receiving all saved artifacts.

    Returns:
        The trainer's result dict (history, test metrics, predictions,
        labels, probabilities).
    """
    logger.info("=" * 60)
    logger.info("ORGANISM CLASSIFICATION TRAINING")
    logger.info("=" * 60)
    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test, metadata = load_data('organism')
    # Create dataloaders (normalization defaults to on)
    train_loader, val_loader, test_loader, scaler = create_dataloaders(
        X_train, X_val, X_test, y_train, y_val, y_test, batch_size
    )
    # Compute class weights to counter label imbalance
    class_weights = compute_class_weights(y_train, 'multiclass')
    # Create model sized from the data and metadata
    model = OrganismClassifier(
        input_size=X_train.shape[1],
        hidden_sizes=hidden_sizes,
        num_classes=len(metadata['class_names']),
        dropout_rate=dropout_rate,
    )
    logger.info(f"Model architecture:\n{model}")
    logger.info(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    # Create trainer
    trainer = OrganismTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        class_weights=class_weights,
        learning_rate=learning_rate,
    )
    # Train
    save_path = Path(save_dir) / 'organism_classifier.pt'
    results = trainer.train(epochs=epochs, save_path=str(save_path))
    # Save scaler so inference can reproduce training-time normalization
    if scaler is not None:
        import joblib
        scaler_path = Path(save_dir) / 'organism_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        logger.info(f"Scaler saved to {scaler_path}")
    # Save metadata and results
    results_path = Path(save_dir) / 'organism_results.json'
    save_results = {
        'metadata': metadata,
        'test_metrics': results['test_metrics'],
        'training_config': {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'hidden_sizes': hidden_sizes,
            'dropout_rate': dropout_rate,
        },
    }
    with open(results_path, 'w') as f:
        json.dump(save_results, f, indent=2)
    # Save training history
    history_path = Path(save_dir) / 'organism_history.json'
    history_save = {
        'train_loss': results['history']['train_loss'],
        'val_loss': results['history']['val_loss'],
        'train_metrics': results['history']['train_metrics'],
        'val_metrics': results['history']['val_metrics'],
        'learning_rates': results['history']['learning_rates'],
    }
    with open(history_path, 'w') as f:
        json.dump(history_save, f, indent=2)
    logger.info(f"\nResults saved to {save_dir}")
    return results
def train_amr_predictor(
    epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    hidden_sizes: List[int] = [512, 256, 128],
    dropout_rate: float = 0.4,
    use_focal_loss: bool = True,
    save_dir: str = 'models',
) -> Dict:
    """Train AMR prediction model.

    End-to-end pipeline: load the preprocessed AMR splits, build
    (normalized) dataloaders, compute per-label positive weights, train an
    AMRPredictor (focal loss by default), then persist the checkpoint,
    feature scaler, results JSON, and history JSON under `save_dir`.

    Args:
        epochs: Maximum training epochs.
        batch_size: Mini-batch size.
        learning_rate: AdamW learning rate.
        hidden_sizes: Trunk widths plus per-head hidden width (last entry).
        dropout_rate: Dropout probability in the trunk.
        use_focal_loss: Use FocalLoss instead of pos-weighted BCE.
        save_dir: Directory receiving all saved artifacts.

    Returns:
        The trainer's result dict (history, test metrics, predictions,
        labels, probabilities).
    """
    logger.info("=" * 60)
    logger.info("AMR PREDICTION TRAINING")
    logger.info("=" * 60)
    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test, metadata = load_data('amr')
    # Create dataloaders (normalization defaults to on)
    train_loader, val_loader, test_loader, scaler = create_dataloaders(
        X_train, X_val, X_test, y_train, y_val, y_test, batch_size
    )
    # Compute positive weights for class imbalance (used only without focal loss)
    pos_weights = compute_class_weights(y_train, 'multilabel')
    # Create model sized from the data and metadata
    model = AMRPredictor(
        input_size=X_train.shape[1],
        hidden_sizes=hidden_sizes,
        num_classes=len(metadata['class_names']),
        dropout_rate=dropout_rate,
    )
    logger.info(f"Model architecture:\n{model}")
    logger.info(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    # Create trainer
    trainer = AMRTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        pos_weights=pos_weights,
        use_focal_loss=use_focal_loss,
        learning_rate=learning_rate,
    )
    # Train
    save_path = Path(save_dir) / 'amr_predictor.pt'
    results = trainer.train(epochs=epochs, save_path=str(save_path))
    # Save scaler so inference can reproduce training-time normalization
    if scaler is not None:
        import joblib
        scaler_path = Path(save_dir) / 'amr_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        logger.info(f"Scaler saved to {scaler_path}")
    # Save metadata and results (ndarray values converted for JSON)
    results_path = Path(save_dir) / 'amr_results.json'
    save_results = {
        'metadata': metadata,
        'test_metrics': {k: v if not isinstance(v, np.ndarray) else v.tolist()
                         for k, v in results['test_metrics'].items()},
        'training_config': {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'hidden_sizes': hidden_sizes,
            'dropout_rate': dropout_rate,
            'use_focal_loss': use_focal_loss,
        },
    }
    with open(results_path, 'w') as f:
        json.dump(save_results, f, indent=2)
    # Save training history
    history_path = Path(save_dir) / 'amr_history.json'
    history_save = {
        'train_loss': results['history']['train_loss'],
        'val_loss': results['history']['val_loss'],
        'train_metrics': results['history']['train_metrics'],
        'val_metrics': [{k: v if not isinstance(v, list) else v
                         for k, v in m.items()} for m in results['history']['val_metrics']],
        'learning_rates': results['history']['learning_rates'],
    }
    with open(history_path, 'w') as f:
        json.dump(history_save, f, indent=2)
    logger.info(f"\nResults saved to {save_dir}")
    return results
| # ============================================================================= | |
| # Main Entry Point | |
| # ============================================================================= | |
def main():
    """CLI entry point: parse arguments and run the requested training task(s)."""
    parser = argparse.ArgumentParser(
        description='Train deep learning models for AMR prediction'
    )
    parser.add_argument(
        '--task',
        type=str,
        choices=['organism', 'amr', 'both'],
        default='both',
        help='Task to train: organism, amr, or both',
    )
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--save-dir', type=str, default='models', help='Save directory')
    args = parser.parse_args()

    # Make sure the artifact directory exists before any trainer writes to it.
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    started_at = datetime.now().strftime('%Y%m%d_%H%M%S')
    logger.info(f"Training started at {started_at}")

    # Both training functions take the same hyperparameter keywords.
    shared_kwargs = dict(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        save_dir=args.save_dir,
    )
    results = {}
    if args.task != 'amr':  # 'organism' or 'both'
        results['organism'] = train_organism_classifier(**shared_kwargs)
    if args.task != 'organism':  # 'amr' or 'both'
        results['amr'] = train_amr_predictor(**shared_kwargs)

    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETE")
    logger.info("=" * 60)
    if 'organism' in results:
        logger.info(f"\nOrganism Classification Test Accuracy: "
                    f"{results['organism']['test_metrics']['accuracy']:.4f}")
    if 'amr' in results:
        logger.info(f"\nAMR Prediction Test F1 (macro): "
                    f"{results['amr']['test_metrics']['f1_macro']:.4f}")
    logger.info(f"\nModels saved to: {args.save_dir}/")


if __name__ == '__main__':
    main()