# deepamr-api / src/ml/amr_classifier.py
# (hossainlab — "Deploy DeepAMR API backend", commit 3255634)
# src/ml/amr_classifier.py
import os

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
class AMRDataset(Dataset):
    """PyTorch Dataset pairing genome feature vectors with resistance labels.

    Args:
        features: array-like of shape (n_samples, n_features).
        labels: array-like of shape (n_samples,) with 0/1 resistance labels
            (kept as float32 because training uses BCELoss).
    """

    def __init__(self, features, labels):
        # torch.as_tensor replaces the legacy FloatTensor constructor
        # (discouraged in modern PyTorch) and avoids an extra copy when the
        # input is already float32.
        self.features = torch.as_tensor(features, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
class AMRClassifier(nn.Module):
    """Feed-forward binary classifier for antimicrobial resistance.

    Architecture: [Linear -> BatchNorm1d -> ReLU -> Dropout] for each hidden
    layer, then Linear(1) + Sigmoid producing a resistance probability.

    Args:
        input_dim: number of input features.
        hidden_dims: hidden layer sizes (any sequence of ints). Tuple default
            fixes the original mutable-default-argument pitfall; list callers
            still work.
        dropout: dropout probability after each hidden layer.
    """

    def __init__(self, input_dim=370, hidden_dims=(512, 256, 128), dropout=0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim
        # Output layer: single logit squashed to a probability in [0, 1].
        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return resistance probabilities of shape (batch, 1)."""
        return self.model(x)
class AMRModelTrainer:
    """Train one binary AMR classifier per antibiotic.

    Trained models are kept in ``self.models`` keyed by antibiotic name;
    checkpoints go to ``models/checkpoints/`` and a results summary to
    ``models/training_results.csv``.
    """

    def __init__(self, feature_extractor, device='cuda'):
        # feature_extractor must expose extract_features(genome_path) -> dict
        # with a 'features' vector (see prepare_dataset) — TODO confirm
        # against CombinedFeatureExtractor.
        self.feature_extractor = feature_extractor
        # Fall back to CPU transparently when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.models = {}  # antibiotic name -> trained AMRClassifier

    def prepare_dataset(self, data_csv='data/processed/training_data.csv'):
        """Extract and cache feature vectors for every genome in data_csv.

        The CSV must provide 'genome_path', 'resistance', 'antibiotic' and
        'sample_id' columns. Samples whose feature extraction fails are
        skipped with a warning (best-effort: one corrupt genome must not
        abort the whole run). The result is dumped to
        data/processed/extracted_features.pkl and returned.
        """
        df = pd.read_csv(data_csv)
        print("Extracting features from genomes...")
        features_list = []
        labels_list = []
        antibiotics_list = []
        for idx, row in df.iterrows():
            if idx % 10 == 0:
                print(f"Processing {idx}/{len(df)}")
            try:
                feature_dict = self.feature_extractor.extract_features(row['genome_path'])
                features_list.append(feature_dict['features'])
                labels_list.append(row['resistance'])
                antibiotics_list.append(row['antibiotic'])
            except Exception as e:
                print(f"Error processing {row['sample_id']}: {e}")
                continue
        processed_data = {
            'features': np.array(features_list),
            'labels': np.array(labels_list),
            'antibiotics': antibiotics_list,
        }
        # The original crashed here when the output directory was missing.
        os.makedirs('data/processed', exist_ok=True)
        joblib.dump(processed_data, 'data/processed/extracted_features.pkl')
        print(f"Saved {len(features_list)} processed samples")
        return processed_data

    def train_model_for_antibiotic(self, antibiotic: str, X, y, epochs=50, batch_size=32):
        """Train and evaluate a classifier for a single antibiotic.

        Args:
            antibiotic: name used for logging and the checkpoint filename.
            X: feature matrix (n_samples, n_features).
            y: binary resistance labels (n_samples,).
            epochs: number of training epochs.
            batch_size: mini-batch size for both loaders.

        Returns:
            (model, best_auc) — the model carries the weights of the
            best-AUC epoch, not merely the final epoch.
        """
        print(f"\n{'='*60}")
        print(f"Training model for {antibiotic}")
        print(f"{'='*60}")
        # Stratified split keeps the resistant/susceptible ratio comparable
        # across train and test.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
        print(f"Resistance ratio - Train: {y_train.mean():.2f}, Test: {y_test.mean():.2f}")
        train_loader = DataLoader(AMRDataset(X_train, y_train),
                                  batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(AMRDataset(X_test, y_test),
                                 batch_size=batch_size, shuffle=False)

        model = AMRClassifier(input_dim=X.shape[1]).to(self.device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        # torch.save does not create directories — ensure it exists up front.
        os.makedirs('models/checkpoints', exist_ok=True)
        checkpoint_path = f'models/checkpoints/{antibiotic}_best.pth'
        best_auc = 0.0
        best_predictions = None  # test-set probabilities at the best epoch

        for epoch in range(epochs):
            # ---- training pass ----
            model.train()
            train_loss = 0.0
            for features, labels in train_loader:
                features = features.to(self.device)
                labels = labels.to(self.device).unsqueeze(1)  # (batch, 1) for BCELoss
                optimizer.zero_grad()
                loss = criterion(model(features), labels)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # ---- evaluation pass ----
            model.eval()
            test_predictions = []
            test_labels = []
            with torch.no_grad():
                for features, labels in test_loader:
                    outputs = model(features.to(self.device))
                    test_predictions.extend(outputs.cpu().numpy())
                    test_labels.extend(labels.numpy())
            test_predictions = np.array(test_predictions)
            test_labels = np.array(test_labels)
            test_auc = roc_auc_score(test_labels, test_predictions)

            # Plateau detection on the (summed) train loss; ReduceLROnPlateau
            # uses a relative threshold by default, so the scale is harmless.
            scheduler.step(train_loss)

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} - "
                      f"Loss: {train_loss/len(train_loader):.4f}, AUC: {test_auc:.4f}")

            if test_auc > best_auc:
                best_auc = test_auc
                best_predictions = test_predictions
                torch.save(model.state_dict(), checkpoint_path)

        # Restore best-epoch weights so the returned/stored model and the
        # classification report agree with the reported best AUC (the
        # original reported last-epoch predictions against the best AUC).
        if best_predictions is not None:
            model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
            test_predictions = best_predictions

        print(f"\nFinal Results for {antibiotic}:")
        print(f"Best AUC: {best_auc:.4f}")
        binary_preds = (test_predictions > 0.5).astype(int).flatten()
        print("\nClassification Report:")
        print(classification_report(test_labels, binary_preds,
                                    target_names=['Susceptible', 'Resistant']))
        self.models[antibiotic] = model
        return model, best_auc

    def train_all_antibiotics(self, min_samples=50):
        """Train one model per antibiotic found in the cached feature file.

        Args:
            min_samples: antibiotics with fewer labelled genomes than this
                are skipped (previously a hard-coded 50).

        Returns:
            dict mapping antibiotic name -> best test AUC; also written to
            models/training_results.csv.
        """
        data = joblib.load('data/processed/extracted_features.pkl')
        features = data['features']
        labels = data['labels']
        antibiotics = data['antibiotics']
        results = {}
        # sorted() makes the training order deterministic across runs
        # (list(set(...)) was not).
        for antibiotic in sorted(set(antibiotics)):
            mask = np.array([ab == antibiotic for ab in antibiotics])
            X_ab = features[mask]
            y_ab = labels[mask]
            if len(X_ab) < min_samples:
                print(f"Skipping {antibiotic} - insufficient data ({len(X_ab)} samples)")
                continue
            model, auc = self.train_model_for_antibiotic(antibiotic, X_ab, y_ab)
            results[antibiotic] = auc
        os.makedirs('models', exist_ok=True)
        results_df = pd.DataFrame.from_dict(results, orient='index', columns=['AUC'])
        results_df.to_csv('models/training_results.csv')
        print("\n" + "=" * 60)
        print("Training Complete! Results:")
        print(results_df)
        return results
# Standalone training entry point.
if __name__ == "__main__":
    from feature_extractor import CombinedFeatureExtractor

    # Wire up the feature pipeline and the trainer.
    extractor = CombinedFeatureExtractor()
    trainer = AMRModelTrainer(extractor)

    # Step 1: extract features from raw genomes (run once, then re-comment).
    # trainer.prepare_dataset()

    # Step 2: train one model per antibiotic from the cached features.
    training_results = trainer.train_all_antibiotics()