# src/ml/amr_classifier.py
"""Neural-network models and training pipeline for antimicrobial
resistance (AMR) prediction: one binary classifier per antibiotic."""

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import joblib


class AMRDataset(Dataset):
    """PyTorch Dataset over pre-extracted feature vectors and binary
    resistance labels."""

    def __init__(self, features, labels):
        # Convert once up front so __getitem__ is a cheap tensor index.
        self.features = torch.FloatTensor(features)
        self.labels = torch.FloatTensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


class AMRClassifier(nn.Module):
    """Feed-forward binary classifier for AMR prediction.

    Per hidden layer: Linear -> BatchNorm -> ReLU -> Dropout, followed
    by a single sigmoid output unit (probability of resistance, paired
    with BCELoss during training).
    """

    def __init__(self, input_dim=370, hidden_dims=None, dropout=0.3):
        super().__init__()
        # Avoid a mutable default argument; (512, 256, 128) preserves
        # the original default architecture.
        if hidden_dims is None:
            hidden_dims = (512, 256, 128)

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        # Output layer: one probability in [0, 1].
        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class AMRModelTrainer:
    """Train one AMR prediction model per antibiotic."""

    def __init__(self, feature_extractor, device='cuda'):
        self.feature_extractor = feature_extractor
        # Fall back to CPU transparently when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.models = {}  # One trained model per antibiotic name.

    def prepare_dataset(self, data_csv='data/processed/training_data.csv'):
        """Extract features from every genome listed in *data_csv* and
        cache them to data/processed/extracted_features.pkl.

        Returns the dict {'features', 'labels', 'antibiotics'}.
        Rows whose feature extraction raises are skipped (logged and
        continued) so one bad genome does not abort the whole run.
        """
        df = pd.read_csv(data_csv)

        print("Extracting features from genomes...")
        features_list = []
        labels_list = []
        antibiotics_list = []

        for idx, row in df.iterrows():
            if idx % 10 == 0:
                print(f"Processing {idx}/{len(df)}")

            try:
                genome_path = row['genome_path']
                feature_dict = self.feature_extractor.extract_features(genome_path)
                features_list.append(feature_dict['features'])
                labels_list.append(row['resistance'])
                antibiotics_list.append(row['antibiotic'])
            except Exception as e:
                # Best-effort: skip unreadable/failed samples, keep going.
                print(f"Error processing {row['sample_id']}: {e}")
                continue

        processed_data = {
            'features': np.array(features_list),
            'labels': np.array(labels_list),
            'antibiotics': antibiotics_list,
        }

        # Ensure the cache directory exists before writing.
        os.makedirs('data/processed', exist_ok=True)
        joblib.dump(processed_data, 'data/processed/extracted_features.pkl')
        print(f"Saved {len(features_list)} processed samples")

        return processed_data

    def _evaluate(self, model, loader):
        """Run *model* over *loader* in eval mode; return
        (predictions, labels) as flat numpy arrays."""
        model.eval()
        predictions = []
        labels_out = []
        with torch.no_grad():
            for features, labels in loader:
                features = features.to(self.device)
                outputs = model(features)
                predictions.extend(outputs.cpu().numpy())
                labels_out.extend(labels.numpy())
        return np.array(predictions), np.array(labels_out)

    def train_model_for_antibiotic(self, antibiotic: str, X, y, epochs=50, batch_size=32):
        """Train a classifier for one antibiotic.

        X: 2-D feature matrix; y: binary resistance labels (1 = resistant).
        Returns (model, best_auc); the best checkpoint is also written to
        models/checkpoints/{antibiotic}_best.pth.
        """
        print(f"\n{'='*60}")
        print(f"Training model for {antibiotic}")
        print(f"{'='*60}")

        # Stratified split keeps the resistance ratio comparable.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
        print(f"Resistance ratio - Train: {y_train.mean():.2f}, Test: {y_test.mean():.2f}")

        train_dataset = AMRDataset(X_train, y_train)
        test_dataset = AMRDataset(X_test, y_test)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        model = AMRClassifier(input_dim=X.shape[1]).to(self.device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        # Make sure the checkpoint directory exists before the first save.
        os.makedirs('models/checkpoints', exist_ok=True)

        best_auc = 0.0
        best_state = None

        for epoch in range(epochs):
            # --- Train ---
            model.train()
            train_loss = 0.0
            for features, labels in train_loader:
                features = features.to(self.device)
                # BCELoss expects targets shaped like the (N, 1) outputs.
                labels = labels.to(self.device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(features)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # --- Evaluate ---
            test_predictions, test_labels = self._evaluate(model, test_loader)
            test_auc = roc_auc_score(test_labels, test_predictions)

            # Step the scheduler on the *average* batch loss, not the sum.
            avg_loss = train_loss / len(train_loader)
            scheduler.step(avg_loss)

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, AUC: {test_auc:.4f}")

            # Keep the best checkpoint (on disk and in memory).
            if test_auc > best_auc:
                best_auc = test_auc
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                torch.save(model.state_dict(), f'models/checkpoints/{antibiotic}_best.pth')

        # Restore the best checkpoint so the final report describes the
        # model that was actually saved/returned. (Previously the report
        # used last-epoch predictions, which could differ from the best.)
        if best_state is not None:
            model.load_state_dict(best_state)
        test_predictions, test_labels = self._evaluate(model, test_loader)

        print(f"\nFinal Results for {antibiotic}:")
        print(f"Best AUC: {best_auc:.4f}")

        binary_preds = (test_predictions > 0.5).astype(int).flatten()
        print("\nClassification Report:")
        print(classification_report(test_labels, binary_preds,
                                    target_names=['Susceptible', 'Resistant']))

        self.models[antibiotic] = model
        return model, best_auc

    def train_all_antibiotics(self):
        """Train a model for every antibiotic with >= 50 samples in the
        cached feature file; write an AUC summary CSV and return
        {antibiotic: best_auc}."""
        data = joblib.load('data/processed/extracted_features.pkl')
        features = data['features']
        labels = data['labels']
        antibiotics = data['antibiotics']

        unique_antibiotics = list(set(antibiotics))
        results = {}

        for antibiotic in unique_antibiotics:
            # Boolean mask selecting this antibiotic's rows.
            mask = np.array([ab == antibiotic for ab in antibiotics])
            X_ab = features[mask]
            y_ab = labels[mask]

            if len(X_ab) < 50:
                print(f"Skipping {antibiotic} - insufficient data ({len(X_ab)} samples)")
                continue

            model, auc = self.train_model_for_antibiotic(antibiotic, X_ab, y_ab)
            results[antibiotic] = auc

        # Save results summary (ensure output directory exists).
        os.makedirs('models', exist_ok=True)
        results_df = pd.DataFrame.from_dict(results, orient='index', columns=['AUC'])
        results_df.to_csv('models/training_results.csv')

        print("\n" + "="*60)
        print("Training Complete! Results:")
        print(results_df)

        return results


# Training script
if __name__ == "__main__":
    from feature_extractor import CombinedFeatureExtractor

    # Initialize
    feature_extractor = CombinedFeatureExtractor()
    trainer = AMRModelTrainer(feature_extractor)

    # Step 1: Extract features (only need to do once)
    # trainer.prepare_dataset()

    # Step 2: Train models
    results = trainer.train_all_antibiotics()