deepamr-api / src /ml /deep_learning_trainer.py
hossainlab's picture
Deploy DeepAMR API backend
3255634
#!/usr/bin/env python3
"""
DeepAMR: Deep Learning Models for Antimicrobial Resistance Prediction
This script trains deep learning models for:
1. Organism Classification (multiclass)
2. AMR Drug Resistance Prediction (multilabel)
Designed for high-impact deployment in Bangladesh healthcare systems.
Usage:
python src/ml/deep_learning_trainer.py --task organism
python src/ml/deep_learning_trainer.py --task amr
python src/ml/deep_learning_trainer.py --task both
"""
import argparse
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
precision_score,
recall_score,
roc_auc_score,
confusion_matrix,
)
from sklearn.preprocessing import StandardScaler
# Configure logging: INFO level with timestamp, level name and message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Set device: prefer CUDA, then the MPS backend, and fall back to the CPU.
# The MPS check is only reached when CUDA is unavailable.
if torch.cuda.is_available():
    _device_name = 'cuda'
elif torch.backends.mps.is_available():
    _device_name = 'mps'
else:
    _device_name = 'cpu'
DEVICE = torch.device(_device_name)
logger.info(f"Using device: {DEVICE}")
# =============================================================================
# Neural Network Architectures
# =============================================================================
class OrganismClassifier(nn.Module):
    """Deep neural network for organism classification from k-mer features.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    groups (one per entry of ``hidden_sizes``) followed by a final ``Linear``
    layer that emits one raw logit per class.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Width of each hidden layer, in order. ``None`` selects
            the default ``[256, 128, 64]``.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied after every hidden layer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: Optional[List[int]] = None,
        num_classes: int = 8,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        # A ``None`` sentinel replaces the former mutable list default so the
        # default can never be shared (and mutated) across instances.
        if hidden_sizes is None:
            hidden_sizes = [256, 128, 64]

        layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise linear layers and reset batch-norm affine params."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for ``x`` (no softmax is applied)."""
        return self.network(x)
class AMRPredictor(nn.Module):
    """Deep neural network for multi-label AMR prediction.

    A shared feature extractor (built from every entry of ``hidden_sizes``
    except the last) feeds ``num_classes`` independent heads. Each head has one
    hidden layer of width ``hidden_sizes[-1]`` and emits a single raw logit;
    the per-head logits are concatenated along dimension 1.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Layer widths. All but the last size the shared trunk,
            the last sizes the hidden layer of every head. ``None`` selects
            the default ``[512, 256, 128]``.
        num_classes: Number of heads, i.e. number of output logits.
        dropout_rate: Dropout probability in the shared trunk; the heads use
            half of this value.

    Raises:
        ValueError: If ``hidden_sizes`` is empty.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: Optional[List[int]] = None,
        num_classes: int = 11,
        dropout_rate: float = 0.4,
    ):
        super().__init__()
        # ``None`` sentinel instead of a mutable list default.
        if hidden_sizes is None:
            hidden_sizes = [512, 256, 128]
        if len(hidden_sizes) == 0:
            # The heads need ``hidden_sizes[-1]``; fail with a clear message
            # rather than an IndexError deep inside layer construction.
            raise ValueError("hidden_sizes must contain at least one layer size")

        # Shared feature extractor
        shared_layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes[:-1]:
            shared_layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.LeakyReLU(0.1),
                nn.Dropout(dropout_rate),
            ])
            prev_size = hidden_size
        self.shared = nn.Sequential(*shared_layers)

        # Drug-class specific heads for better performance
        head_size = hidden_sizes[-1]
        self.drug_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(prev_size, head_size),
                nn.BatchNorm1d(head_size),
                nn.LeakyReLU(0.1),
                nn.Dropout(dropout_rate * 0.5),
                nn.Linear(head_size, 1),
            )
            for _ in range(num_classes)
        ])

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise linear layers and reset batch-norm affine params."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per head, concatenated along dimension 1."""
        shared_features = self.shared(x)
        outputs = [head(shared_features) for head in self.drug_heads]
        return torch.cat(outputs, dim=1)
class ResidualBlock(nn.Module):
    """Fully connected residual unit that keeps the feature width unchanged.

    The inner branch is ``Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear
    -> BatchNorm1d``; its output is added to the input and passed through a
    final ReLU.
    """

    def __init__(self, size: int, dropout_rate: float = 0.3):
        super().__init__()
        branch_layers = [
            nn.Linear(size, size),
            nn.BatchNorm1d(size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(size, size),
            nn.BatchNorm1d(size),
        ]
        self.block = nn.Sequential(*branch_layers)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the skip connection and the closing activation."""
        summed = x + self.block(x)
        return self.relu(summed)
class DeepAMRNet(nn.Module):
    """Residual multilayer network for AMR prediction.

    Layout: an input projection to ``hidden_size``, ``num_residual_blocks``
    residual blocks of that width, then an output stage that halves the width
    before producing ``num_classes`` raw logits.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int = 256,
        num_residual_blocks: int = 3,
        num_classes: int = 11,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        half_width = hidden_size // 2

        # Projection from the raw features to the working width.
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # Stack of identical-width residual units.
        stacked_blocks = []
        for _ in range(num_residual_blocks):
            stacked_blocks.append(ResidualBlock(hidden_size, dropout_rate))
        self.residual_blocks = nn.Sequential(*stacked_blocks)

        # Output stage; dropout here is half the rate used elsewhere.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size, half_width),
            nn.BatchNorm1d(half_width),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(half_width, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every linear layer and zero its bias."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits for ``x``."""
        projected = self.input_layer(x)
        refined = self.residual_blocks(projected)
        return self.output_layer(refined)
# =============================================================================
# Training Utilities
# =============================================================================
class EarlyStopping:
    """Signal when a monitored score has stopped improving.

    Call the instance once per epoch with the monitored score. It returns
    ``True`` once ``patience`` non-improving calls have accumulated since the
    last improvement. ``mode='min'`` treats lower scores as better; any other
    value treats higher scores as better. A change only counts as an
    improvement when it exceeds ``min_delta``.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """Record ``score`` and return whether training should stop."""
        # The very first score just becomes the baseline.
        if self.best_score is None:
            self.best_score = score
            return self.early_stop

        if self._is_improvement(score):
            self.best_score = score
            self.counter = 0
            return self.early_stop

        # No improvement: use up one unit of patience.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

    def _is_improvement(self, score: float) -> bool:
        """Return True when ``score`` beats the best score by more than ``min_delta``."""
        if self.mode != 'min':
            return score > self.best_score + self.min_delta
        return score < self.best_score - self.min_delta
class FocalLoss(nn.Module):
    """Focal loss on logits for multilabel classification.

    Each element's binary cross-entropy is scaled by ``alpha`` and by
    ``(1 - p_t) ** gamma``, where ``p_t = exp(-bce)``; the mean over all
    elements is returned. Note that ``alpha`` is applied uniformly to every
    element here.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean focal loss of ``inputs`` (logits) against ``targets``."""
        elementwise_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-bce) recovers the probability assigned to the true label.
        true_label_prob = torch.exp(-elementwise_bce)
        modulator = (1 - true_label_prob) ** self.gamma
        weighted = self.alpha * modulator * elementwise_bce
        return weighted.mean()
def compute_class_weights(y: np.ndarray, task: str = 'multiclass') -> torch.Tensor:
    """Compute class weights to handle imbalanced data.

    Args:
        y: For ``task='multiclass'``, a 1-D array of non-negative integer
            class ids. Otherwise an array whose column sums give the number of
            positives per label.
        task: ``'multiclass'`` for balanced per-class weights
            (``n_samples / (n_classes * count)``); any other value selects the
            multilabel negative/positive ratio, clipped to ``[1, 10]``.

    Returns:
        A float32 tensor with one weight per class / label.
    """
    if task == 'multiclass':
        class_counts = np.bincount(y)
        # Class ids below the maximum that never occur have a count of zero;
        # floor the count at 1 so their weight stays finite instead of
        # dividing by zero and producing ``inf``.
        safe_counts = np.maximum(class_counts, 1)
        weights = len(y) / (len(class_counts) * safe_counts)
    else:  # multilabel
        pos_counts = y.sum(axis=0)
        neg_counts = len(y) - pos_counts
        weights = neg_counts / (pos_counts + 1e-6)
        weights = np.clip(weights, 1.0, 10.0)  # Clip extreme weights
    return torch.as_tensor(weights, dtype=torch.float32)
# =============================================================================
# Trainer Classes
# =============================================================================
class BaseTrainer:
    """Shared trainer state: model, data loaders, optimiser, LR scheduler,
    training history and checkpoint save/load helpers."""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        learning_rate: float = 0.001,
        weight_decay: float = 0.01,
        device: torch.device = DEVICE,
    ):
        self.device = device
        self.model = model.to(device)

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # AdamW over all model parameters.
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        # Halve the learning rate after 5 scheduler steps without a decrease
        # in the monitored quantity.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
        )

        # Per-epoch records appended to by the subclasses' train loops.
        history_keys = (
            'train_loss',
            'val_loss',
            'train_metrics',
            'val_metrics',
            'learning_rates',
        )
        self.history = {key: [] for key in history_keys}

    def save_checkpoint(self, path: str, epoch: int, metrics: Dict):
        """Write model, optimiser and scheduler state plus metrics/history to ``path``."""
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'metrics': metrics,
                'history': self.history,
            },
            path,
        )
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore state saved by ``save_checkpoint``; return ``(epoch, metrics)``."""
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state['model_state_dict'])
        self.optimizer.load_state_dict(state['optimizer_state_dict'])
        self.scheduler.load_state_dict(state['scheduler_state_dict'])
        self.history = state['history']
        return state['epoch'], state['metrics']
class OrganismTrainer(BaseTrainer):
    """Trainer for organism classification.

    Optimises a (optionally class-weighted) cross-entropy loss and selects the
    best model by macro-averaged F1 on the validation loader.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        class_weights: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Set up the base trainer and the cross-entropy criterion.

        Args:
            model: Network producing one logit per class.
            train_loader, val_loader, test_loader: Loaders yielding
                ``(features, labels)`` batches.
            class_weights: Optional per-class weights for the loss; moved to
                the trainer's device.
            **kwargs: Forwarded to ``BaseTrainer.__init__``.
        """
        super().__init__(model, train_loader, val_loader, test_loader, **kwargs)
        if class_weights is not None:
            class_weights = class_weights.to(self.device)
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

    def train_epoch(self) -> Tuple[float, Dict]:
        """Train for one epoch.

        Returns:
            ``(avg_loss, metrics)`` where ``avg_loss`` is the mean of the
            per-batch losses and ``metrics`` holds accuracy and F1 scores
            computed on the training predictions made during the epoch.
        """
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        for batch_x, batch_y in self.train_loader:
            batch_x = batch_x.to(self.device)
            batch_y = batch_y.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(batch_x)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_y.cpu().numpy())

        avg_loss = total_loss / len(self.train_loader)
        metrics = {
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_weighted': f1_score(all_labels, all_preds, average='weighted'),
        }
        return avg_loss, metrics

    def validate(self, loader: DataLoader) -> Tuple[float, Dict, List, List, np.ndarray]:
        """Evaluate the model on ``loader`` without updating weights.

        Returns:
            ``(avg_loss, metrics, predictions, labels, probabilities)``.
            ``probabilities`` is the softmax output stacked into an array.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()
                probs = torch.softmax(outputs, dim=1).cpu().numpy()
                preds = outputs.argmax(dim=1).cpu().numpy()
                all_probs.extend(probs)
                all_preds.extend(preds)
                all_labels.extend(batch_y.cpu().numpy())

        avg_loss = total_loss / len(loader)
        all_probs = np.array(all_probs)
        metrics = {
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_weighted': f1_score(all_labels, all_preds, average='weighted'),
            'precision_macro': precision_score(all_labels, all_preds, average='macro'),
            'recall_macro': recall_score(all_labels, all_preds, average='macro'),
        }
        # ROC-AUC for multiclass; scikit-learn raises ValueError when the
        # score is undefined for this split, in which case 0.0 is recorded.
        try:
            metrics['roc_auc'] = roc_auc_score(
                all_labels, all_probs, multi_class='ovr', average='macro'
            )
        except ValueError as exc:
            logger.debug("ROC-AUC could not be computed: %s", exc)
            metrics['roc_auc'] = 0.0
        return avg_loss, metrics, all_preds, all_labels, all_probs

    def train(
        self,
        epochs: int = 100,
        patience: int = 15,
        save_path: str = 'models/organism_classifier.pt',
    ) -> Dict:
        """Full training loop followed by a test-set evaluation.

        The checkpoint at ``save_path`` is overwritten whenever validation
        macro-F1 improves. Before the final test evaluation the weights of
        that best epoch are restored, so the reported test metrics describe
        the checkpoint that was saved rather than the last epoch.

        Args:
            epochs: Maximum number of epochs.
            patience: Early-stopping patience on validation macro-F1.
            save_path: Destination of the best checkpoint.

        Returns:
            Dict with ``history``, ``test_metrics``, ``test_predictions``,
            ``test_labels`` and ``test_probabilities``.
        """
        import copy

        early_stopping = EarlyStopping(patience=patience, mode='max')
        # Start below any attainable F1 so the first epoch is always saved,
        # even when its F1 is exactly 0.
        best_f1 = float('-inf')
        best_state = None

        logger.info("Starting organism classification training...")
        logger.info(f"Training samples: {len(self.train_loader.dataset)}")
        logger.info(f"Validation samples: {len(self.val_loader.dataset)}")

        for epoch in range(epochs):
            # Train
            train_loss, train_metrics = self.train_epoch()
            # Validate
            val_loss, val_metrics, _, _, _ = self.validate(self.val_loader)
            # Update scheduler
            self.scheduler.step(val_loss)

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_metrics'].append(train_metrics)
            self.history['val_metrics'].append(val_metrics)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])

            # Logging
            logger.info(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                f"Train Acc: {train_metrics['accuracy']:.4f} | "
                f"Val Acc: {val_metrics['accuracy']:.4f} | "
                f"Val F1: {val_metrics['f1_macro']:.4f}"
            )

            # Save best model
            if val_metrics['f1_macro'] > best_f1:
                best_f1 = val_metrics['f1_macro']
                # state_dict() hands out references to the live tensors, so a
                # deep copy is needed to freeze this epoch's weights.
                best_state = copy.deepcopy(self.model.state_dict())
                Path(save_path).parent.mkdir(parents=True, exist_ok=True)
                self.save_checkpoint(save_path, epoch, val_metrics)
                logger.info(f"New best model saved! F1: {best_f1:.4f}")

            # Early stopping
            if early_stopping(val_metrics['f1_macro']):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Evaluate the best weights, not whatever the last epoch left behind.
        if best_state is not None:
            self.model.load_state_dict(best_state)

        # Final evaluation on test set
        logger.info("\nEvaluating on test set...")
        test_loss, test_metrics, test_preds, test_labels, test_probs = self.validate(
            self.test_loader
        )
        logger.info("\nTest Results:")
        logger.info(f" Accuracy: {test_metrics['accuracy']:.4f}")
        logger.info(f" F1 (macro): {test_metrics['f1_macro']:.4f}")
        logger.info(f" F1 (weighted): {test_metrics['f1_weighted']:.4f}")
        logger.info(f" ROC-AUC: {test_metrics['roc_auc']:.4f}")

        return {
            'history': self.history,
            'test_metrics': test_metrics,
            'test_predictions': test_preds,
            'test_labels': test_labels,
            'test_probabilities': test_probs,
        }
class AMRTrainer(BaseTrainer):
    """Trainer for multilabel AMR prediction.

    Optimises either focal loss or (optionally positive-weighted) binary
    cross-entropy on logits and selects the best model by macro-averaged F1
    on the validation loader. Predictions threshold the sigmoid at 0.5.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        pos_weights: Optional[torch.Tensor] = None,
        use_focal_loss: bool = True,
        **kwargs,
    ):
        """Set up the base trainer and choose the loss.

        Args:
            model: Network producing one logit per label.
            train_loader, val_loader, test_loader: Loaders yielding
                ``(features, labels)`` batches.
            pos_weights: Optional per-label positive weights. Only used when
                ``use_focal_loss`` is False.
            use_focal_loss: Use ``FocalLoss(alpha=0.25, gamma=2.0)`` instead
                of ``BCEWithLogitsLoss``.
            **kwargs: Forwarded to ``BaseTrainer.__init__``.
        """
        super().__init__(model, train_loader, val_loader, test_loader, **kwargs)
        if use_focal_loss:
            if pos_weights is not None:
                # Make the previously silent behaviour visible to the user.
                logger.info("pos_weights are ignored because focal loss is enabled")
            self.criterion = FocalLoss(alpha=0.25, gamma=2.0)
        else:
            if pos_weights is not None:
                pos_weights = pos_weights.to(self.device)
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

    def train_epoch(self) -> Tuple[float, Dict]:
        """Train for one epoch.

        Returns:
            ``(avg_loss, metrics)`` where ``avg_loss`` is the mean of the
            per-batch losses and ``metrics`` holds micro/macro/samples F1
            computed on the training predictions made during the epoch.
        """
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        for batch_x, batch_y in self.train_loader:
            batch_x = batch_x.to(self.device)
            batch_y = batch_y.float().to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(batch_x)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_y.cpu().numpy())

        avg_loss = total_loss / len(self.train_loader)
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        metrics = {
            'f1_micro': f1_score(all_labels, all_preds, average='micro'),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_samples': f1_score(all_labels, all_preds, average='samples'),
        }
        return avg_loss, metrics

    def validate(
        self, loader: DataLoader
    ) -> Tuple[float, Dict, np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate the model on ``loader`` without updating weights.

        Returns:
            ``(avg_loss, metrics, predictions, labels, probabilities)``, the
            last three as arrays with one row per sample.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.float().to(self.device)
                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()
                probs = torch.sigmoid(outputs).cpu().numpy()
                preds = (probs > 0.5).astype(int)
                all_probs.extend(probs)
                all_preds.extend(preds)
                all_labels.extend(batch_y.cpu().numpy())

        avg_loss = total_loss / len(loader)
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)
        metrics = {
            'f1_micro': f1_score(all_labels, all_preds, average='micro'),
            'f1_macro': f1_score(all_labels, all_preds, average='macro'),
            'f1_samples': f1_score(all_labels, all_preds, average='samples'),
            'precision_micro': precision_score(all_labels, all_preds, average='micro'),
            'recall_micro': recall_score(all_labels, all_preds, average='micro'),
        }
        # Per-class metrics
        per_class_f1 = f1_score(all_labels, all_preds, average=None)
        metrics['per_class_f1'] = per_class_f1.tolist()

        # ROC-AUC: each average is computed on its own so that a failure of
        # one does not discard a valid value of the other. 0.0 is recorded
        # when scikit-learn reports the score as undefined.
        for average in ('micro', 'macro'):
            key = f'roc_auc_{average}'
            try:
                metrics[key] = roc_auc_score(all_labels, all_probs, average=average)
            except ValueError as exc:
                logger.debug("%s could not be computed: %s", key, exc)
                metrics[key] = 0.0
        return avg_loss, metrics, all_preds, all_labels, all_probs

    def train(
        self,
        epochs: int = 100,
        patience: int = 15,
        save_path: str = 'models/amr_predictor.pt',
    ) -> Dict:
        """Full training loop followed by a test-set evaluation.

        The checkpoint at ``save_path`` is overwritten whenever validation
        macro-F1 improves. Before the final test evaluation the weights of
        that best epoch are restored, so the reported test metrics describe
        the checkpoint that was saved rather than the last epoch.

        Args:
            epochs: Maximum number of epochs.
            patience: Early-stopping patience on validation macro-F1.
            save_path: Destination of the best checkpoint.

        Returns:
            Dict with ``history``, ``test_metrics``, ``test_predictions``,
            ``test_labels`` and ``test_probabilities``.
        """
        import copy

        early_stopping = EarlyStopping(patience=patience, mode='max')
        # Start below any attainable F1 so the first epoch is always saved,
        # even when its F1 is exactly 0.
        best_f1 = float('-inf')
        best_state = None

        logger.info("Starting AMR prediction training...")
        logger.info(f"Training samples: {len(self.train_loader.dataset)}")
        logger.info(f"Validation samples: {len(self.val_loader.dataset)}")

        for epoch in range(epochs):
            # Train
            train_loss, train_metrics = self.train_epoch()
            # Validate
            val_loss, val_metrics, _, _, _ = self.validate(self.val_loader)
            # Update scheduler
            self.scheduler.step(val_loss)

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_metrics'].append(train_metrics)
            self.history['val_metrics'].append(val_metrics)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])

            # Logging
            logger.info(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                f"Train F1: {train_metrics['f1_macro']:.4f} | "
                f"Val F1: {val_metrics['f1_macro']:.4f} | "
                f"Val AUC: {val_metrics.get('roc_auc_macro', 0):.4f}"
            )

            # Save best model
            if val_metrics['f1_macro'] > best_f1:
                best_f1 = val_metrics['f1_macro']
                # state_dict() hands out references to the live tensors, so a
                # deep copy is needed to freeze this epoch's weights.
                best_state = copy.deepcopy(self.model.state_dict())
                Path(save_path).parent.mkdir(parents=True, exist_ok=True)
                self.save_checkpoint(save_path, epoch, val_metrics)
                logger.info(f"New best model saved! F1: {best_f1:.4f}")

            # Early stopping
            if early_stopping(val_metrics['f1_macro']):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Evaluate the best weights, not whatever the last epoch left behind.
        if best_state is not None:
            self.model.load_state_dict(best_state)

        # Final evaluation on test set
        logger.info("\nEvaluating on test set...")
        test_loss, test_metrics, test_preds, test_labels, test_probs = self.validate(
            self.test_loader
        )
        logger.info("\nTest Results:")
        logger.info(f" F1 (micro): {test_metrics['f1_micro']:.4f}")
        logger.info(f" F1 (macro): {test_metrics['f1_macro']:.4f}")
        logger.info(f" F1 (samples): {test_metrics['f1_samples']:.4f}")
        logger.info(f" ROC-AUC (macro): {test_metrics['roc_auc_macro']:.4f}")

        return {
            'history': self.history,
            'test_metrics': test_metrics,
            'test_predictions': test_preds,
            'test_labels': test_labels,
            'test_probabilities': test_probs,
        }
# =============================================================================
# Data Loading
# =============================================================================
def load_data(task: str = 'organism') -> Tuple:
    """Read the preprocessed train/val/test arrays and metadata for ``task``.

    Files are read from ``data/processed/ncbi``. ``task='organism'`` selects
    the ``ncbi_organism`` file prefix; any other value selects ``ncbi_amr``.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test, metadata)``.
    """
    data_dir = Path('data/processed/ncbi')
    prefix = 'ncbi_organism' if task == 'organism' else 'ncbi_amr'

    def _load_array(kind: str, split: str) -> np.ndarray:
        # e.g. kind='X', split='train' -> '<prefix>_X_train.npy'
        return np.load(data_dir / f'{prefix}_{kind}_{split}.npy')

    X_train = _load_array('X', 'train')
    X_val = _load_array('X', 'val')
    X_test = _load_array('X', 'test')
    y_train = _load_array('y', 'train')
    y_val = _load_array('y', 'val')
    y_test = _load_array('y', 'test')

    with open(data_dir / f'{prefix}_metadata.json') as f:
        metadata = json.load(f)

    logger.info(f"Loaded {task} data:")
    logger.info(f" Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    return X_train, X_val, X_test, y_train, y_val, y_test, metadata
def create_dataloaders(
    X_train: np.ndarray,
    X_val: np.ndarray,
    X_test: np.ndarray,
    y_train: np.ndarray,
    y_val: np.ndarray,
    y_test: np.ndarray,
    batch_size: int = 32,
    normalize: bool = True,
) -> Tuple[DataLoader, DataLoader, DataLoader, Optional["StandardScaler"]]:
    """Create PyTorch DataLoaders.

    Args:
        X_train, X_val, X_test: Feature arrays for each split.
        y_train, y_val, y_test: Target arrays. 1-D targets become long
            tensors, anything else becomes float tensors.
        batch_size: Batch size for all three loaders.
        normalize: When True, fit a ``StandardScaler`` on the training
            features and apply it to all three splits.

    Returns:
        ``(train_loader, val_loader, test_loader, scaler)``; ``scaler`` is
        ``None`` when ``normalize`` is False. Only the training loader
        shuffles.
    """
    scaler = None
    if normalize:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)

    def _to_dataset(features: np.ndarray, targets: np.ndarray) -> TensorDataset:
        # 1-D targets are class ids (long); otherwise multi-column floats.
        if targets.ndim == 1:
            target_tensor = torch.LongTensor(targets)
        else:
            target_tensor = torch.FloatTensor(targets)
        return TensorDataset(torch.FloatTensor(features), target_tensor)

    train_dataset = _to_dataset(X_train, y_train)
    val_dataset = _to_dataset(X_val, y_val)
    test_dataset = _to_dataset(X_test, y_test)

    # Every model in this file uses BatchNorm1d, which raises in training mode
    # on a batch holding a single sample. Drop the trailing batch only when it
    # would contain exactly one sample.
    n_train = len(train_dataset)
    drop_single_tail = n_train > batch_size and n_train % batch_size == 1

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_single_tail,
    )
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    return train_loader, val_loader, test_loader, scaler
# =============================================================================
# Main Training Functions
# =============================================================================
def train_organism_classifier(
    epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    hidden_sizes: Optional[List[int]] = None,
    dropout_rate: float = 0.3,
    save_dir: str = 'models',
) -> Dict:
    """Train organism classification model.

    Loads the organism data, trains an ``OrganismClassifier`` and writes the
    checkpoint, the fitted scaler, a results JSON and a history JSON into
    ``save_dir``.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size for all data loaders.
        learning_rate: Optimiser learning rate.
        hidden_sizes: Hidden layer widths; ``None`` selects ``[256, 128, 64]``.
        dropout_rate: Dropout probability of the classifier.
        save_dir: Output directory (created if missing).

    Returns:
        The result dict returned by ``OrganismTrainer.train``.
    """
    # ``None`` sentinel instead of a mutable list default.
    if hidden_sizes is None:
        hidden_sizes = [256, 128, 64]

    logger.info("=" * 60)
    logger.info("ORGANISM CLASSIFICATION TRAINING")
    logger.info("=" * 60)

    # The scaler and JSON files below are written here regardless of whether
    # the trainer ever saved a checkpoint, so make sure the directory exists.
    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test, metadata = load_data('organism')

    # Create dataloaders
    train_loader, val_loader, test_loader, scaler = create_dataloaders(
        X_train, X_val, X_test, y_train, y_val, y_test, batch_size
    )

    # Compute class weights
    class_weights = compute_class_weights(y_train, 'multiclass')

    # Create model
    model = OrganismClassifier(
        input_size=X_train.shape[1],
        hidden_sizes=hidden_sizes,
        num_classes=len(metadata['class_names']),
        dropout_rate=dropout_rate,
    )
    logger.info(f"Model architecture:\n{model}")
    logger.info(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create trainer
    trainer = OrganismTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        class_weights=class_weights,
        learning_rate=learning_rate,
    )

    # Train
    save_path = output_dir / 'organism_classifier.pt'
    results = trainer.train(epochs=epochs, save_path=str(save_path))

    # Save scaler
    if scaler is not None:
        import joblib
        scaler_path = output_dir / 'organism_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        logger.info(f"Scaler saved to {scaler_path}")

    # Save metadata and results
    results_path = output_dir / 'organism_results.json'
    save_results = {
        'metadata': metadata,
        'test_metrics': results['test_metrics'],
        'training_config': {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'hidden_sizes': hidden_sizes,
            'dropout_rate': dropout_rate,
        },
    }
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(save_results, f, indent=2)

    # Save training history
    history_path = output_dir / 'organism_history.json'
    history = results['history']
    history_save = {
        'train_loss': history['train_loss'],
        'val_loss': history['val_loss'],
        'train_metrics': history['train_metrics'],
        'val_metrics': history['val_metrics'],
        'learning_rates': history['learning_rates'],
    }
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history_save, f, indent=2)

    logger.info(f"\nResults saved to {save_dir}")
    return results
def train_amr_predictor(
    epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    hidden_sizes: Optional[List[int]] = None,
    dropout_rate: float = 0.4,
    use_focal_loss: bool = True,
    save_dir: str = 'models',
) -> Dict:
    """Train AMR prediction model.

    Loads the AMR data, trains an ``AMRPredictor`` and writes the checkpoint,
    the fitted scaler, a results JSON and a history JSON into ``save_dir``.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size for all data loaders.
        learning_rate: Optimiser learning rate.
        hidden_sizes: Layer widths; ``None`` selects ``[512, 256, 128]``.
        dropout_rate: Dropout probability of the predictor.
        use_focal_loss: Passed through to ``AMRTrainer``.
        save_dir: Output directory (created if missing).

    Returns:
        The result dict returned by ``AMRTrainer.train``.
    """
    # ``None`` sentinel instead of a mutable list default.
    if hidden_sizes is None:
        hidden_sizes = [512, 256, 128]

    logger.info("=" * 60)
    logger.info("AMR PREDICTION TRAINING")
    logger.info("=" * 60)

    # The scaler and JSON files below are written here regardless of whether
    # the trainer ever saved a checkpoint, so make sure the directory exists.
    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test, metadata = load_data('amr')

    # Create dataloaders
    train_loader, val_loader, test_loader, scaler = create_dataloaders(
        X_train, X_val, X_test, y_train, y_val, y_test, batch_size
    )

    # Compute positive weights for class imbalance
    pos_weights = compute_class_weights(y_train, 'multilabel')

    # Create model
    model = AMRPredictor(
        input_size=X_train.shape[1],
        hidden_sizes=hidden_sizes,
        num_classes=len(metadata['class_names']),
        dropout_rate=dropout_rate,
    )
    logger.info(f"Model architecture:\n{model}")
    logger.info(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create trainer
    trainer = AMRTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        pos_weights=pos_weights,
        use_focal_loss=use_focal_loss,
        learning_rate=learning_rate,
    )

    # Train
    save_path = output_dir / 'amr_predictor.pt'
    results = trainer.train(epochs=epochs, save_path=str(save_path))

    # Save scaler
    if scaler is not None:
        import joblib
        scaler_path = output_dir / 'amr_scaler.joblib'
        joblib.dump(scaler, scaler_path)
        logger.info(f"Scaler saved to {scaler_path}")

    # Save metadata and results; array-valued metrics are converted to lists
    # so they can be serialised as JSON.
    results_path = output_dir / 'amr_results.json'
    save_results = {
        'metadata': metadata,
        'test_metrics': {
            k: v.tolist() if isinstance(v, np.ndarray) else v
            for k, v in results['test_metrics'].items()
        },
        'training_config': {
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'hidden_sizes': hidden_sizes,
            'dropout_rate': dropout_rate,
            'use_focal_loss': use_focal_loss,
        },
    }
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(save_results, f, indent=2)

    # Save training history. The former per-entry comprehension over
    # val_metrics returned every value unchanged, so the list is written
    # directly.
    history_path = output_dir / 'amr_history.json'
    history = results['history']
    history_save = {
        'train_loss': history['train_loss'],
        'val_loss': history['val_loss'],
        'train_metrics': history['train_metrics'],
        'val_metrics': history['val_metrics'],
        'learning_rates': history['learning_rates'],
    }
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history_save, f, indent=2)

    logger.info(f"\nResults saved to {save_dir}")
    return results
# =============================================================================
# Main Entry Point
# =============================================================================
def main():
    """Command-line entry point: parse arguments and run the selected task(s)."""
    parser = argparse.ArgumentParser(
        description='Train deep learning models for AMR prediction'
    )
    parser.add_argument(
        '--task',
        type=str,
        choices=['organism', 'amr', 'both'],
        default='both',
        help='Task to train: organism, amr, or both',
    )
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--save-dir', type=str, default='models', help='Save directory')
    args = parser.parse_args()

    # Create save directory
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    # Training timestamp
    started_at = datetime.now().strftime('%Y%m%d_%H%M%S')
    logger.info(f"Training started at {started_at}")

    run_organism = args.task in ('organism', 'both')
    run_amr = args.task in ('amr', 'both')

    results = {}
    if run_organism:
        results['organism'] = train_organism_classifier(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            save_dir=args.save_dir,
        )
    if run_amr:
        results['amr'] = train_amr_predictor(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            save_dir=args.save_dir,
        )

    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("TRAINING COMPLETE")
    logger.info(banner)

    # Headline test metric of each task that was run.
    if 'organism' in results:
        organism_accuracy = results['organism']['test_metrics']['accuracy']
        logger.info(f"\nOrganism Classification Test Accuracy: "
                    f"{organism_accuracy:.4f}")
    if 'amr' in results:
        amr_f1_macro = results['amr']['test_metrics']['f1_macro']
        logger.info(f"\nAMR Prediction Test F1 (macro): "
                    f"{amr_f1_macro:.4f}")
    logger.info(f"\nModels saved to: {args.save_dir}/")


if __name__ == '__main__':
    main()