cathrica commited on
Commit
9ea3b1c
·
verified ·
1 Parent(s): 1b1c554

Add training script for all 3 models

Browse files
Files changed (1) hide show
  1. experiments/train_baseline.py +282 -0
experiments/train_baseline.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for all three IDS models.
3
+ Trains MLP, LSTM, and 1D-CNN on NSL-KDD with full evaluation.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import time
10
+ import random
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import TensorDataset, DataLoader
15
+ from sklearn.metrics import (classification_report, confusion_matrix,
16
+ roc_auc_score, average_precision_score)
17
+
18
+ # Add project root to path
19
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
+
21
+ from models.mlp_baseline import MLP_IDS
22
+ from models.lstm_model import LSTM_IDS
23
+ from models.cnn1d_model import CNN1D_IDS
24
+ from data.preprocess import load_nsl_kdd, preprocess, save_preprocessed, FEATURE_NAMES
25
+
26
+ # ========================
27
+ # Reproducibility
28
+ # ========================
29
+ SEED = 42
30
+ random.seed(SEED)
31
+ np.random.seed(SEED)
32
+ torch.manual_seed(SEED)
33
+ torch.backends.cudnn.deterministic = True
34
+ torch.backends.cudnn.benchmark = False
35
+
36
+ # ========================
37
+ # Config
38
+ # ========================
39
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+ NUM_CLASSES = 2 # Binary classification
41
+ EPOCHS = 50
42
+ BATCH_SIZE = 256
43
+ LR = 1e-3
44
+ WEIGHT_DECAY = 1e-4
45
+ RESULTS_DIR = 'results'
46
+ MODELS_DIR = 'saved_models'
47
+
48
+
49
+ def compute_class_weights(y_train):
50
+ """Compute inverse-frequency class weights."""
51
+ counts = np.bincount(y_train)
52
+ weights = 1.0 / counts.astype(np.float32)
53
+ weights = weights / weights.sum() * len(weights) # Normalize
54
+ return torch.FloatTensor(weights).to(DEVICE)
55
+
56
+
57
+ def train_one_epoch(model, loader, criterion, optimizer):
58
+ """Train for one epoch."""
59
+ model.train()
60
+ total_loss = 0
61
+ correct = 0
62
+ total = 0
63
+
64
+ for X_batch, y_batch in loader:
65
+ X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
66
+
67
+ optimizer.zero_grad()
68
+ outputs = model(X_batch)
69
+ loss = criterion(outputs, y_batch)
70
+ loss.backward()
71
+ optimizer.step()
72
+
73
+ total_loss += loss.item() * len(y_batch)
74
+ preds = outputs.argmax(dim=1)
75
+ correct += (preds == y_batch).sum().item()
76
+ total += len(y_batch)
77
+
78
+ return total_loss / total, correct / total
79
+
80
+
81
+ @torch.no_grad()
82
+ def evaluate(model, loader, criterion):
83
+ """Evaluate model on dataset."""
84
+ model.eval()
85
+ total_loss = 0
86
+ all_preds = []
87
+ all_probs = []
88
+ all_labels = []
89
+
90
+ for X_batch, y_batch in loader:
91
+ X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
92
+
93
+ outputs = model(X_batch)
94
+ loss = criterion(outputs, y_batch)
95
+
96
+ total_loss += loss.item() * len(y_batch)
97
+ probs = torch.softmax(outputs, dim=1)
98
+ all_preds.append(outputs.argmax(dim=1).cpu().numpy())
99
+ all_probs.append(probs.cpu().numpy())
100
+ all_labels.append(y_batch.cpu().numpy())
101
+
102
+ all_preds = np.concatenate(all_preds)
103
+ all_probs = np.concatenate(all_probs)
104
+ all_labels = np.concatenate(all_labels)
105
+
106
+ avg_loss = total_loss / len(all_labels)
107
+
108
+ return avg_loss, all_preds, all_probs, all_labels
109
+
110
+
111
+ def full_evaluation(y_true, y_pred, y_probs, class_names):
112
+ """Compute all metrics."""
113
+ results = {}
114
+
115
+ # Classification report
116
+ report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
117
+ results['classification_report'] = report
118
+
119
+ # ROC-AUC (binary)
120
+ if len(class_names) == 2:
121
+ results['roc_auc'] = roc_auc_score(y_true, y_probs[:, 1])
122
+ results['pr_auc'] = average_precision_score(y_true, y_probs[:, 1])
123
+
124
+ # Confusion matrix
125
+ cm = confusion_matrix(y_true, y_pred)
126
+ results['confusion_matrix'] = cm.tolist()
127
+
128
+ return results
129
+
130
+
131
+ def train_model(model, model_name, X_train, y_train, X_test, y_test, class_names):
132
+ """Full training pipeline for one model."""
133
+ print(f"\n{'='*60}")
134
+ print(f"Training {model_name}")
135
+ print(f"{'='*60}")
136
+ print(f"Parameters: {model.count_parameters():,}")
137
+ print(f"Device: {DEVICE}")
138
+
139
+ # Data loaders
140
+ train_ds = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))
141
+ test_ds = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))
142
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
143
+ test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
144
+
145
+ # Loss with class weights
146
+ class_weights = compute_class_weights(y_train)
147
+ criterion = nn.CrossEntropyLoss(weight=class_weights)
148
+
149
+ # Optimizer
150
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
151
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
152
+
153
+ # Training loop
154
+ model.to(DEVICE)
155
+ best_f1 = 0
156
+ history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
157
+
158
+ start_time = time.time()
159
+
160
+ for epoch in range(EPOCHS):
161
+ train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
162
+ test_loss, test_preds, test_probs, test_labels = evaluate(model, test_loader, criterion)
163
+ test_acc = (test_preds == test_labels).mean()
164
+
165
+ scheduler.step(test_loss)
166
+
167
+ history['train_loss'].append(train_loss)
168
+ history['train_acc'].append(train_acc)
169
+ history['test_loss'].append(test_loss)
170
+ history['test_acc'].append(test_acc)
171
+
172
+ # Check for best model
173
+ report = classification_report(test_labels, test_preds, output_dict=True)
174
+ weighted_f1 = report['weighted avg']['f1-score']
175
+
176
+ if weighted_f1 > best_f1:
177
+ best_f1 = weighted_f1
178
+ os.makedirs(MODELS_DIR, exist_ok=True)
179
+ torch.save(model.state_dict(), os.path.join(MODELS_DIR, f'{model_name}_best.pt'))
180
+
181
+ if (epoch + 1) % 10 == 0 or epoch == 0:
182
+ print(f" Epoch {epoch+1:3d}/{EPOCHS} | "
183
+ f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
184
+ f"Test Loss: {test_loss:.4f} Acc: {test_acc:.4f} F1: {weighted_f1:.4f}")
185
+
186
+ train_time = time.time() - start_time
187
+ print(f"\n Training time: {train_time:.1f}s")
188
+
189
+ # Load best model and final evaluation
190
+ model.load_state_dict(torch.load(os.path.join(MODELS_DIR, f'{model_name}_best.pt'),
191
+ weights_only=True))
192
+ _, final_preds, final_probs, final_labels = evaluate(model, test_loader, criterion)
193
+
194
+ results = full_evaluation(final_labels, final_preds, final_probs, class_names)
195
+ results['training_time'] = train_time
196
+ results['best_weighted_f1'] = best_f1
197
+ results['history'] = history
198
+ results['parameters'] = model.count_parameters()
199
+
200
+ # Print final results
201
+ print(f"\n Final Results ({model_name}):")
202
+ print(f" {'='*50}")
203
+ print(classification_report(final_labels, final_preds, target_names=class_names))
204
+
205
+ if 'roc_auc' in results:
206
+ print(f" ROC-AUC: {results['roc_auc']:.4f}")
207
+ print(f" PR-AUC: {results['pr_auc']:.4f}")
208
+
209
+ print(f" Confusion Matrix:\n{confusion_matrix(final_labels, final_preds)}")
210
+
211
+ return model, results
212
+
213
+
214
+ def main():
215
+ # ========================
216
+ # Data
217
+ # ========================
218
+ df_train, df_test = load_nsl_kdd()
219
+ X_train, X_test, y_train, y_test, le, scaler, class_names = preprocess(
220
+ df_train, df_test, binary=True
221
+ )
222
+ save_preprocessed(X_train, X_test, y_train, y_test, le, scaler, class_names)
223
+
224
+ # ========================
225
+ # Train all models
226
+ # ========================
227
+ all_results = {}
228
+
229
+ # 1. MLP Baseline
230
+ mlp = MLP_IDS(in_dim=41, num_classes=NUM_CLASSES)
231
+ mlp, mlp_results = train_model(mlp, 'mlp', X_train, y_train, X_test, y_test, class_names)
232
+ all_results['mlp'] = mlp_results
233
+
234
+ # 2. LSTM
235
+ lstm = LSTM_IDS(in_dim=41, num_classes=NUM_CLASSES)
236
+ lstm, lstm_results = train_model(lstm, 'lstm', X_train, y_train, X_test, y_test, class_names)
237
+ all_results['lstm'] = lstm_results
238
+
239
+ # 3. 1D-CNN
240
+ cnn = CNN1D_IDS(in_dim=41, num_classes=NUM_CLASSES)
241
+ cnn, cnn_results = train_model(cnn, 'cnn1d', X_train, y_train, X_test, y_test, class_names)
242
+ all_results['cnn1d'] = cnn_results
243
+
244
+ # ========================
245
+ # Save results
246
+ # ========================
247
+ os.makedirs(RESULTS_DIR, exist_ok=True)
248
+
249
+ def convert(o):
250
+ if isinstance(o, np.floating): return float(o)
251
+ if isinstance(o, np.integer): return int(o)
252
+ if isinstance(o, np.ndarray): return o.tolist()
253
+ return o
254
+
255
+ with open(os.path.join(RESULTS_DIR, 'training_results.json'), 'w') as f:
256
+ json.dump(all_results, f, indent=2, default=convert)
257
+
258
+ # ========================
259
+ # Summary comparison
260
+ # ========================
261
+ print("\n" + "="*60)
262
+ print("MODEL COMPARISON SUMMARY")
263
+ print("="*60)
264
+ print(f"{'Model':<10} {'Params':>8} {'Accuracy':>10} {'W-F1':>8} {'ROC-AUC':>9} {'PR-AUC':>8} {'Time':>8}")
265
+ print("-"*60)
266
+
267
+ for name, res in all_results.items():
268
+ acc = res['classification_report']['accuracy']
269
+ wf1 = res['best_weighted_f1']
270
+ roc = res.get('roc_auc', 0)
271
+ pr = res.get('pr_auc', 0)
272
+ t = res['training_time']
273
+ p = res['parameters']
274
+ print(f"{name:<10} {p:>8,} {acc:>10.4f} {wf1:>8.4f} {roc:>9.4f} {pr:>8.4f} {t:>7.1f}s")
275
+
276
+ print("\nAll models trained successfully!")
277
+ print(f"Results saved to {RESULTS_DIR}/training_results.json")
278
+ print(f"Models saved to {MODELS_DIR}/")
279
+
280
+
281
+ if __name__ == '__main__':
282
+ main()