File size: 8,317 Bytes
3255634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# src/ml/amr_classifier.py
import os

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

class AMRDataset(Dataset):
    """Map-style dataset pairing AMR feature vectors with resistance labels.

    Both inputs are converted eagerly with ``torch.FloatTensor`` so every
    item is served as a pair of float32 tensors.

    Args:
        features: array-like of per-sample feature vectors (presumably
            ``(n_samples, n_features)`` — TODO confirm against callers).
        labels: array-like of per-sample numeric labels.
    """

    def __init__(self, features, labels):
        to_float = torch.FloatTensor
        self.features = to_float(features)
        self.labels = to_float(labels)

    def __len__(self):
        # The number of samples is defined by the label tensor.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target


class AMRClassifier(nn.Module):
    """Feed-forward network producing a resistance probability per sample.

    For each entry of ``hidden_dims`` the network stacks
    Linear -> BatchNorm1d -> ReLU -> Dropout, then ends with a single-unit
    Linear layer followed by a Sigmoid, so ``forward`` returns values in
    [0, 1] (the trainer pairs it with ``nn.BCELoss``).

    Args:
        input_dim: number of input features per sample.
        hidden_dims: hidden layer sizes, in order. ``None`` (the default)
            means ``(512, 256, 128)``. Any iterable of ints is accepted.
        dropout: dropout probability applied after every hidden block.
    """

    def __init__(self, input_dim=370, hidden_dims=None, dropout=0.3):
        super().__init__()

        # Resolve the default here instead of using a mutable list literal in
        # the signature, which would be a single object shared by all calls.
        if hidden_dims is None:
            hidden_dims = (512, 256, 128)

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        # Output head: one unit squashed to a probability.
        layers.extend([nn.Linear(prev_dim, 1), nn.Sigmoid()])

        # Attribute name kept as ``model`` so saved state_dict keys
        # ("model.0.weight", ...) stay compatible with existing checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        ``x`` is assumed to be ``(batch, input_dim)`` — TODO confirm. In
        training mode BatchNorm1d needs more than one sample per batch.
        """
        return self.model(x)


class AMRModelTrainer:
    """Train one ``AMRClassifier`` per antibiotic from extracted genome features.

    Flow: ``prepare_dataset`` extracts features for the genomes listed in a
    CSV and caches them with ``joblib``. ``train_all_antibiotics`` loads that
    cache, splits it by antibiotic and calls ``train_model_for_antibiotic``
    for every antibiotic with enough usable data.

    Args:
        feature_extractor: object exposing ``extract_features(genome_path)``
            that returns a dict with a ``'features'`` entry (the only key
            read here).
        device: preferred torch device string. CPU is used whenever CUDA is
            not available.
    """

    def __init__(self, feature_extractor, device='cuda'):
        self.feature_extractor = feature_extractor
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.models = {}  # antibiotic -> trained AMRClassifier (best-epoch weights)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` (no-op for bare names)."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def prepare_dataset(self, data_csv='data/processed/training_data.csv',
                        output_path='data/processed/extracted_features.pkl'):
        """Extract features for every row of ``data_csv`` and cache them.

        The CSV is read for the ``genome_path``, ``resistance``, ``antibiotic``
        and (for error messages) ``sample_id`` columns. Rows whose extraction
        or lookup raises are reported and skipped (best-effort), so the cache
        can hold fewer samples than the CSV has rows.

        Args:
            data_csv: path of the training table.
            output_path: where the ``joblib`` cache is written. Parent
                directories are created if missing.

        Returns:
            dict with ``'features'`` (np.ndarray, one entry per kept sample),
            ``'labels'`` (np.ndarray) and ``'antibiotics'`` (list), all
            aligned by position.
        """
        df = pd.read_csv(data_csv)
        n_rows = len(df)

        print("Extracting features from genomes...")
        features_list = []
        labels_list = []
        antibiotics_list = []

        # enumerate() gives a 0-based position even if df has a custom index.
        for position, (_, row) in enumerate(df.iterrows()):
            if position % 10 == 0:
                print(f"Processing {position}/{n_rows}")

            try:
                feature_dict = self.feature_extractor.extract_features(row['genome_path'])
                # Gather all three values before appending anything so a
                # failure cannot leave the parallel lists misaligned.
                features = feature_dict['features']
                label = row['resistance']
                antibiotic = row['antibiotic']
            except Exception as e:
                # Best-effort: one bad genome must not abort the whole run.
                print(f"Error processing {row.get('sample_id', position)}: {e}")
                continue

            features_list.append(features)
            labels_list.append(label)
            antibiotics_list.append(antibiotic)

        processed_data = {
            'features': np.array(features_list),
            'labels': np.array(labels_list),
            'antibiotics': antibiotics_list,
        }

        self._ensure_parent_dir(output_path)
        joblib.dump(processed_data, output_path)
        print(f"Saved {len(features_list)} processed samples")

        return processed_data

    def _train_one_epoch(self, model, loader, criterion, optimizer):
        """Run one optimisation pass over ``loader`` and return the mean batch loss."""
        model.train()
        total_loss = 0.0
        n_batches = 0
        for features, labels in loader:
            features = features.to(self.device)
            # BCELoss needs targets shaped like the (batch, 1) model output.
            labels = labels.to(self.device).unsqueeze(1)

            optimizer.zero_grad()
            loss = criterion(model(features), labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1
        return total_loss / max(n_batches, 1)

    def _predict(self, model, loader):
        """Return flat numpy arrays ``(probabilities, labels)`` for ``loader`` in eval mode."""
        model.eval()
        predictions, targets = [], []
        with torch.no_grad():
            for features, labels in loader:
                outputs = model(features.to(self.device))
                predictions.append(outputs.cpu().numpy().reshape(-1))
                targets.append(labels.numpy().reshape(-1))
        if not predictions:
            return np.empty(0), np.empty(0)
        return np.concatenate(predictions), np.concatenate(targets)

    def train_model_for_antibiotic(self, antibiotic: str, X, y, epochs=50, batch_size=32,
                                   checkpoint_dir='models/checkpoints'):
        """Train, select and report an ``AMRClassifier`` for one antibiotic.

        The data is split 80/20, stratified on ``y`` with ``random_state=42``.
        After every epoch the held-out split is scored with ROC AUC. The
        best-scoring epoch's weights are saved to
        ``{checkpoint_dir}/{antibiotic}_best.pth`` and restored into the model
        at the end. The final classification report, the returned model and
        ``self.models[antibiotic]`` therefore all correspond to ``best_auc``.

        Args:
            antibiotic: name used for logging and the checkpoint filename.
                Any ``/`` is replaced by ``_`` in the filename.
            X: feature matrix with one row per sample.
            y: binary labels (0 = susceptible, 1 = resistant). Assumed numeric;
                TODO confirm with the data source.
            epochs: number of training epochs (must be >= 1).
            batch_size: mini-batch size for training and evaluation.
            checkpoint_dir: directory for the best checkpoint (created if missing).

        Returns:
            ``(model, best_auc)``.

        Raises:
            ValueError: if ``epochs < 1`` or ``y`` holds a single class.
        """
        if epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs}")
        X = np.asarray(X)
        y = np.asarray(y)
        if np.unique(y).size < 2:
            raise ValueError(
                f"Cannot train {antibiotic}: labels contain a single class, "
                "so a stratified split and ROC AUC are undefined"
            )

        print(f"\n{'='*60}")
        print(f"Training model for {antibiotic}")
        print(f"{'='*60}")

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
        print(f"Resistance ratio - Train: {y_train.mean():.2f}, Test: {y_test.mean():.2f}")

        train_dataset = AMRDataset(X_train, y_train)
        test_dataset = AMRDataset(X_test, y_test)

        # BatchNorm1d cannot normalise a one-sample batch in training mode, so
        # drop the final training batch only when it would hold exactly one row.
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            drop_last=len(train_dataset) % batch_size == 1,
        )
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        model = AMRClassifier(input_dim=X.shape[1]).to(self.device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        os.makedirs(checkpoint_dir, exist_ok=True)
        # Names such as "trimethoprim/sulfamethoxazole" contain a path
        # separator. Replace it so the file lands directly in checkpoint_dir.
        safe_name = str(antibiotic).replace('/', '_').replace(os.sep, '_')
        checkpoint_path = os.path.join(checkpoint_dir, f'{safe_name}_best.pth')

        best_auc = float('-inf')
        best_state = None
        for epoch in range(epochs):
            mean_train_loss = self._train_one_epoch(model, train_loader, criterion, optimizer)
            test_predictions, test_labels = self._predict(model, test_loader)
            test_auc = roc_auc_score(test_labels, test_predictions)

            # Plateau detection on the mean training loss. This is a constant
            # multiple of the summed loss, so it schedules identically.
            scheduler.step(mean_train_loss)

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} - Loss: {mean_train_loss:.4f}, AUC: {test_auc:.4f}")

            if test_auc > best_auc:
                best_auc = test_auc
                torch.save(model.state_dict(), checkpoint_path)
                # Clone so later optimiser steps do not overwrite the snapshot.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if best_state is not None:
            model.load_state_dict(best_state)
        test_predictions, test_labels = self._predict(model, test_loader)

        print(f"\nFinal Results for {antibiotic}:")
        print(f"Best AUC: {best_auc:.4f}")

        binary_preds = (test_predictions > 0.5).astype(int)
        print("\nClassification Report:")
        print(classification_report(test_labels.astype(int), binary_preds,
                                    labels=[0, 1],
                                    target_names=['Susceptible', 'Resistant'],
                                    zero_division=0))

        self.models[antibiotic] = model
        return model, best_auc

    def train_all_antibiotics(self, features_path='data/processed/extracted_features.pkl',
                              results_path='models/training_results.csv',
                              min_samples=50):
        """Train one model per antibiotic present in the cached feature file.

        Antibiotics are processed in sorted order, which is deterministic
        across runs (iterating a ``set`` of strings is not). An antibiotic is
        skipped, with a message, if it has fewer than ``min_samples`` samples,
        or if its labels lack two classes with at least two samples each.
        Stratified splitting and ROC AUC need that.

        Args:
            features_path: ``joblib`` cache written by ``prepare_dataset``.
            results_path: CSV the per-antibiotic AUC summary is written to.
            min_samples: minimum number of samples required to train.

        Returns:
            dict mapping antibiotic -> best held-out ROC AUC.
        """
        # NOTE(review): joblib.load unpickles; only load trusted caches.
        data = joblib.load(features_path)

        features = np.asarray(data['features'])
        labels = np.asarray(data['labels'])
        antibiotics = data['antibiotics']

        results = {}
        for antibiotic in sorted(set(antibiotics), key=str):
            mask = np.array([ab == antibiotic for ab in antibiotics], dtype=bool)
            X_ab = features[mask]
            y_ab = labels[mask]

            if len(X_ab) < min_samples:
                print(f"Skipping {antibiotic} - insufficient data ({len(X_ab)} samples)")
                continue

            _, class_counts = np.unique(y_ab, return_counts=True)
            if class_counts.size < 2 or class_counts.min() < 2:
                print(f"Skipping {antibiotic} - needs at least 2 samples of each resistance class")
                continue

            _, auc = self.train_model_for_antibiotic(antibiotic, X_ab, y_ab)
            results[antibiotic] = auc

        results_df = pd.DataFrame.from_dict(results, orient='index', columns=['AUC'])
        self._ensure_parent_dir(results_path)
        results_df.to_csv(results_path)
        print("\n" + "="*60)
        print("Training Complete! Results:")
        print(results_df)

        return results


# Training script: optional one-off feature extraction, then model training.
if __name__ == "__main__":
    from feature_extractor import CombinedFeatureExtractor

    extractor = CombinedFeatureExtractor()
    trainer = AMRModelTrainer(extractor)

    # Feature extraction only needs to run once: prepare_dataset() writes the
    # cached features that train_all_antibiotics() loads. Uncomment to rebuild.
    # trainer.prepare_dataset()

    results = trainer.train_all_antibiotics()