cathrica committed on
Commit
f5ee9d3
·
verified ·
1 Parent(s): 0d2d577

Add Colab notebook — full pipeline in one file

Browse files
Files changed (1) hide show
  1. explainable_ids_full_pipeline.ipynb +743 -0
explainable_ids_full_pipeline.ipynb ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Explainable IDS — Full Pipeline\n",
8
+ "**ICCN-INE2 Project 5 | NSL-KDD | MLP + LSTM + 1D-CNN | SHAP + LIME**\n",
9
+ "\n",
10
+ "Run all cells (Runtime → Run all, or Ctrl+F9). Takes ~10-15 min on a Colab GPU."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "## 0. Setup"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "!pip install -q torch numpy pandas scikit-learn datasets shap lime matplotlib scipy"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "import os, sys, json, time, random, pickle\n",
36
+ "import numpy as np\n",
37
+ "import pandas as pd\n",
38
+ "import torch\n",
39
+ "import torch.nn as nn\n",
40
+ "from torch.utils.data import TensorDataset, DataLoader\n",
41
+ "from sklearn.preprocessing import LabelEncoder, MinMaxScaler\n",
42
+ "from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score\n",
43
+ "from datasets import load_dataset\n",
44
+ "import shap\n",
45
+ "from lime import lime_tabular\n",
46
+ "from scipy.stats import spearmanr, pearsonr\n",
47
+ "import matplotlib.pyplot as plt\n",
48
+ "import warnings\n",
49
+ "warnings.filterwarnings('ignore')\n",
50
+ "\n",
51
+ "# Reproducibility\n",
52
+ "SEED = 42\n",
53
+ "random.seed(SEED)\n",
54
+ "np.random.seed(SEED)\n",
55
+ "torch.manual_seed(SEED)\n",
56
+ "torch.backends.cudnn.deterministic = True\n",
57
+ "torch.backends.cudnn.benchmark = False\n",
58
+ "\n",
59
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
60
+ "print(f'Device: {DEVICE}')\n",
61
+ "if DEVICE.type == 'cuda':\n",
62
+ " print(f'GPU: {torch.cuda.get_device_name(0)}')"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {},
68
+ "source": [
69
+ "## 1. Load & Preprocess NSL-KDD"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "FEATURE_NAMES = [\n",
79
+ " 'duration', 'protocol_type', 'service', 'flag',\n",
80
+ " 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent',\n",
81
+ " 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',\n",
82
+ " 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',\n",
83
+ " 'num_shells', 'num_access_files', 'num_outbound_cmds',\n",
84
+ " 'is_host_login', 'is_guest_login',\n",
85
+ " 'count', 'srv_count',\n",
86
+ " 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
87
+ " 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',\n",
88
+ " 'dst_host_count', 'dst_host_srv_count',\n",
89
+ " 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',\n",
90
+ " 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate',\n",
91
+ " 'dst_host_serror_rate', 'dst_host_srv_serror_rate',\n",
92
+ " 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate'\n",
93
+ "]\n",
94
+ "CATEGORICAL_COLS = ['protocol_type', 'service', 'flag']\n",
95
+ "\n",
96
+ "# Load from HuggingFace\n",
97
+ "ds = load_dataset('Mireu-Lab/NSL-KDD')\n",
98
+ "df_train = ds['train'].to_pandas()\n",
99
+ "df_test = ds['test'].to_pandas()\n",
100
+ "print(f'Train: {len(df_train)} | Test: {len(df_test)}')\n",
101
+ "\n",
102
+ "# Class distribution\n",
103
+ "print('\\nTrain distribution:')\n",
104
+ "print(df_train['class'].value_counts())\n",
105
+ "print('\\nTest distribution:')\n",
106
+ "print(df_test['class'].value_counts())"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "# Encode target (binary: anomaly=0, normal=1)\n",
116
+ "class_names = ['anomaly', 'normal']\n",
117
+ "le_y = LabelEncoder()\n",
118
+ "y_train = le_y.fit_transform(df_train['class'].values)\n",
119
+ "y_test = le_y.transform(df_test['class'].values)\n",
120
+ "\n",
121
+ "# Encode categoricals\n",
122
+ "df_tr, df_te = df_train.copy(), df_test.copy()\n",
123
+ "label_encoders = {}\n",
124
+ "for col in CATEGORICAL_COLS:\n",
125
+ " le = LabelEncoder()\n",
126
+ " le.fit(df_tr[col])\n",
127
+ " known = set(le.classes_)\n",
128
+ " df_te[col] = df_te[col].apply(lambda x: x if x in known else le.classes_[0])\n",
129
+ " df_tr[col] = le.transform(df_tr[col])\n",
130
+ " df_te[col] = le.transform(df_te[col])\n",
131
+ " label_encoders[col] = le\n",
132
+ " print(f'Encoded {col}: {len(le.classes_)} categories')\n",
133
+ "\n",
134
+ "# Scale features\n",
135
+ "scaler = MinMaxScaler()\n",
136
+ "X_train = scaler.fit_transform(df_tr[FEATURE_NAMES].values.astype(np.float32))\n",
137
+ "X_test = scaler.transform(df_te[FEATURE_NAMES].values.astype(np.float32))\n",
138
+ "\n",
139
+ "print(f'\\nX_train: {X_train.shape} | X_test: {X_test.shape}')\n",
140
+ "print(f'y_train: {np.bincount(y_train)} | y_test: {np.bincount(y_test)}')"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "markdown",
145
+ "metadata": {},
146
+ "source": [
147
+ "## 2. Model Definitions"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "class MLP_IDS(nn.Module):\n",
157
+ " def __init__(self, in_dim=41, num_classes=2):\n",
158
+ " super().__init__()\n",
159
+ " self.net = nn.Sequential(\n",
160
+ " nn.Linear(in_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),\n",
161
+ " nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),\n",
162
+ " nn.Linear(128, 64), nn.ReLU(),\n",
163
+ " nn.Linear(64, num_classes)\n",
164
+ " )\n",
165
+ " for m in self.modules():\n",
166
+ " if isinstance(m, nn.Linear):\n",
167
+ " nn.init.xavier_uniform_(m.weight)\n",
168
+ " nn.init.zeros_(m.bias)\n",
169
+ " def forward(self, x): return self.net(x)\n",
170
+ " def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
171
+ "\n",
172
+ "class LSTM_IDS(nn.Module):\n",
173
+ " def __init__(self, in_dim=41, hidden_dim=64, num_layers=2, num_classes=2):\n",
174
+ " super().__init__()\n",
175
+ " self.lstm = nn.LSTM(1, hidden_dim, num_layers, batch_first=True, dropout=0.2)\n",
176
+ " self.fc = nn.Sequential(nn.Linear(hidden_dim, 32), nn.ReLU(), nn.Linear(32, num_classes))\n",
177
+ " def forward(self, x):\n",
178
+ " out, (h_n, _) = self.lstm(x.unsqueeze(-1))\n",
179
+ " return self.fc(h_n[-1])\n",
180
+ " def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
181
+ "\n",
182
+ "class CNN1D_IDS(nn.Module):\n",
183
+ " def __init__(self, in_dim=41, num_classes=2):\n",
184
+ " super().__init__()\n",
185
+ " self.conv = nn.Sequential(\n",
186
+ " nn.Conv1d(1, 64, 3, padding=1), nn.BatchNorm1d(64), nn.ReLU(),\n",
187
+ " nn.Conv1d(64, 128, 3, padding=1), nn.BatchNorm1d(128), nn.ReLU(),\n",
188
+ " nn.AdaptiveAvgPool1d(8)\n",
189
+ " )\n",
190
+ " self.fc = nn.Sequential(nn.Linear(128*8, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, num_classes))\n",
191
+ " def forward(self, x):\n",
192
+ " x = self.conv(x.unsqueeze(1))\n",
193
+ " return self.fc(x.view(x.size(0), -1))\n",
194
+ " def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
195
+ "\n",
196
+ "for name, cls in [('MLP', MLP_IDS), ('LSTM', LSTM_IDS), ('CNN1D', CNN1D_IDS)]:\n",
197
+ " m = cls()\n",
198
+ " print(f'{name}: {m.count_parameters():,} parameters')"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "metadata": {},
204
+ "source": [
205
+ "## 3. Train All Models"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "EPOCHS = 50\n",
215
+ "BATCH_SIZE = 256\n",
216
+ "LR = 1e-3\n",
217
+ "\n",
218
+ "# Data loaders\n",
219
+ "train_ds = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))\n",
220
+ "test_ds = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))\n",
221
+ "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n",
222
+ "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)\n",
223
+ "\n",
224
+ "# Class weights\n",
225
+ "counts = np.bincount(y_train)\n",
226
+ "weights = 1.0 / counts.astype(np.float32)\n",
227
+ "weights = weights / weights.sum() * len(weights)\n",
228
+ "class_weights = torch.FloatTensor(weights).to(DEVICE)\n",
229
+ "\n",
230
+ "def train_model(model, model_name):\n",
231
+ " print(f'\\n{\"=\"*60}')\n",
232
+ " print(f'Training {model_name} ({model.count_parameters():,} params) on {DEVICE}')\n",
233
+ " print(f'{\"=\"*60}')\n",
234
+ " \n",
235
+ " model.to(DEVICE)\n",
236
+ " criterion = nn.CrossEntropyLoss(weight=class_weights)\n",
237
+ " optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)\n",
238
+ " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)\n",
239
+ " \n",
240
+ " best_f1, history = 0, {'train_loss': [], 'test_acc': []}\n",
241
+ " best_state = None\n",
242
+ " t0 = time.time()\n",
243
+ " \n",
244
+ " for epoch in range(EPOCHS):\n",
245
+ " model.train()\n",
246
+ " total_loss = 0\n",
247
+ " for xb, yb in train_loader:\n",
248
+ " xb, yb = xb.to(DEVICE), yb.to(DEVICE)\n",
249
+ " optimizer.zero_grad()\n",
250
+ " loss = criterion(model(xb), yb)\n",
251
+ " loss.backward()\n",
252
+ " optimizer.step()\n",
253
+ " total_loss += loss.item() * len(yb)\n",
254
+ " \n",
255
+ " # Evaluate\n",
256
+ " model.eval()\n",
257
+ " preds, probs, labels = [], [], []\n",
258
+ " with torch.no_grad():\n",
259
+ " for xb, yb in test_loader:\n",
260
+ " xb = xb.to(DEVICE)\n",
261
+ " out = model(xb)\n",
262
+ " preds.append(out.argmax(1).cpu().numpy())\n",
263
+ " probs.append(torch.softmax(out, 1).cpu().numpy())\n",
264
+ " labels.append(yb.numpy())\n",
265
+ " preds = np.concatenate(preds)\n",
266
+ " probs = np.concatenate(probs)\n",
267
+ " labels = np.concatenate(labels)\n",
268
+ " \n",
269
+ " report = classification_report(labels, preds, output_dict=True)\n",
270
+ " wf1 = report['weighted avg']['f1-score']\n",
271
+ " acc = report['accuracy']\n",
272
+ " test_loss = total_loss / len(y_train)\n",
273
+ " scheduler.step(test_loss)\n",
274
+ " \n",
275
+ " history['train_loss'].append(total_loss / len(y_train))\n",
276
+ " history['test_acc'].append(acc)\n",
277
+ " \n",
278
+ " if wf1 > best_f1:\n",
279
+ " best_f1 = wf1\n",
280
+ " best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
281
+ " \n",
282
+ " if (epoch+1) % 10 == 0 or epoch == 0:\n",
283
+ " print(f' Epoch {epoch+1:3d}/{EPOCHS} | Loss: {total_loss/len(y_train):.4f} | Acc: {acc:.4f} | F1: {wf1:.4f}')\n",
284
+ " \n",
285
+ " dt = time.time() - t0\n",
286
+ " \n",
287
+ " # Restore the best-F1 weights, then run the final evaluation\n",
288
+ " model.load_state_dict(best_state)\n",
289
+ " model.eval()\n",
290
+ " preds, probs, labels = [], [], []\n",
291
+ " with torch.no_grad():\n",
292
+ " for xb, yb in test_loader:\n",
293
+ " xb = xb.to(DEVICE)\n",
294
+ " out = model(xb)\n",
295
+ " preds.append(out.argmax(1).cpu().numpy())\n",
296
+ " probs.append(torch.softmax(out, 1).cpu().numpy())\n",
297
+ " labels.append(yb.numpy())\n",
298
+ " preds = np.concatenate(preds)\n",
299
+ " probs = np.concatenate(probs)\n",
300
+ " labels = np.concatenate(labels)\n",
301
+ " \n",
302
+ " roc = roc_auc_score(labels, probs[:, 1])\n",
303
+ " pr = average_precision_score(labels, probs[:, 1])\n",
304
+ " \n",
305
+ " print(f'\\n Time: {dt:.1f}s | Best F1: {best_f1:.4f} | ROC-AUC: {roc:.4f} | PR-AUC: {pr:.4f}')\n",
306
+ " print(classification_report(labels, preds, target_names=class_names))\n",
307
+ " print('Confusion Matrix:')\n",
308
+ " print(confusion_matrix(labels, preds))\n",
309
+ " \n",
310
+ " return model, {'f1': best_f1, 'roc_auc': roc, 'pr_auc': pr, 'time': dt, 'history': history, 'preds': preds, 'probs': probs, 'labels': labels}\n",
311
+ "\n",
312
+ "# Train all 3\n",
313
+ "models = {}\n",
314
+ "results = {}\n",
315
+ "for name, cls in [('mlp', MLP_IDS), ('lstm', LSTM_IDS), ('cnn1d', CNN1D_IDS)]:\n",
316
+ " models[name], results[name] = train_model(cls(), name.upper())"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": null,
322
+ "metadata": {},
323
+ "outputs": [],
324
+ "source": [
325
+ "# Summary table\n",
326
+ "print(f'{\"Model\":<8} {\"Params\":>8} {\"W-F1\":>8} {\"ROC-AUC\":>9} {\"PR-AUC\":>8} {\"Time\":>8}')\n",
327
+ "print('-'*50)\n",
328
+ "for name in ['mlp', 'lstm', 'cnn1d']:\n",
329
+ " r = results[name]\n",
330
+ " p = models[name].count_parameters()\n",
331
+ " print(f'{name:<8} {p:>8,} {r[\"f1\"]:>8.4f} {r[\"roc_auc\"]:>9.4f} {r[\"pr_auc\"]:>8.4f} {r[\"time\"]:>7.1f}s')"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "# Training curves\n",
341
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
342
+ "for name in ['mlp', 'lstm', 'cnn1d']:\n",
343
+ " axes[0].plot(results[name]['history']['train_loss'], label=name.upper())\n",
344
+ " axes[1].plot(results[name]['history']['test_acc'], label=name.upper())\n",
345
+ "axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Train Loss'); axes[0].set_title('Training Loss'); axes[0].legend(); axes[0].grid(alpha=0.3)\n",
346
+ "axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Test Accuracy'); axes[1].set_title('Test Accuracy'); axes[1].legend(); axes[1].grid(alpha=0.3)\n",
347
+ "plt.tight_layout(); plt.show()"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "metadata": {},
353
+ "source": [
354
+ "## 4. SHAP Explainability Analysis"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# Move MLP to CPU for SHAP\n",
364
+ "mlp_cpu = models['mlp'].cpu().eval()\n",
365
+ "\n",
366
+ "def predict_fn(X):\n",
367
+ " with torch.no_grad():\n",
368
+ " return torch.softmax(mlp_cpu(torch.FloatTensor(X)), 1).numpy()\n",
369
+ "\n",
370
+ "# Background & samples\n",
371
+ "bg_idx = np.random.choice(len(X_train), 100, replace=False)\n",
372
+ "exp_idx = np.random.choice(len(X_test), 150, replace=False)\n",
373
+ "\n",
374
+ "explainer = shap.KernelExplainer(predict_fn, X_train[bg_idx])\n",
375
+ "print('Computing SHAP values for 150 test samples (this takes a few minutes)...')\n",
376
+ "shap_values = explainer.shap_values(X_test[exp_idx], nsamples=200, silent=True)\n",
377
+ "print('Done!')"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": null,
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": [
386
+ "# Global feature importance (anomaly class)\n",
387
+ "mean_abs_shap = np.abs(shap_values[0]).mean(axis=0)\n",
388
+ "feature_importance = sorted(zip(FEATURE_NAMES, mean_abs_shap), key=lambda x: x[1], reverse=True)\n",
389
+ "\n",
390
+ "print('Top 15 features by mean |SHAP| (anomaly class):')\n",
391
+ "for i, (f, v) in enumerate(feature_importance[:15]):\n",
392
+ " print(f' {i+1:2d}. {f:35s} {v:.4f}')"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "# SHAP summary plot\n",
402
+ "shap.summary_plot(shap_values[0], X_test[exp_idx], feature_names=FEATURE_NAMES, max_display=15)"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "# SHAP bar plot\n",
412
+ "plt.figure(figsize=(10, 6))\n",
413
+ "top15 = feature_importance[:15]\n",
414
+ "plt.barh(range(15), [v for _, v in top15][::-1], color='steelblue')\n",
415
+ "plt.yticks(range(15), [f for f, _ in top15][::-1])\n",
416
+ "plt.xlabel('Mean |SHAP value|')\n",
417
+ "plt.title('Top 15 Features — MLP (Anomaly Class)')\n",
418
+ "plt.tight_layout(); plt.show()"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "# Single prediction explanation (force plot)\n",
428
+ "idx = 0\n",
429
+ "pred = predict_fn(X_test[exp_idx[idx:idx+1]])\n",
430
+ "print(f'Sample prediction: anomaly={pred[0][0]:.3f}, normal={pred[0][1]:.3f}')\n",
431
+ "print(f'True label: {class_names[y_test[exp_idx[idx]]]}')\n",
432
+ "shap.force_plot(explainer.expected_value[0], shap_values[0][idx], X_test[exp_idx[idx]], feature_names=FEATURE_NAMES, matplotlib=True)"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {},
438
+ "source": [
439
+ "## 5. LIME Analysis"
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "code",
444
+ "execution_count": null,
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "lime_explainer = lime_tabular.LimeTabularExplainer(\n",
449
+ " X_train, feature_names=FEATURE_NAMES, class_names=class_names,\n",
450
+ " discretize_continuous=True, random_state=SEED\n",
451
+ ")\n",
452
+ "\n",
453
+ "n_lime = 30\n",
454
+ "lime_idx = np.random.choice(len(X_test), n_lime, replace=False)\n",
455
+ "all_top_features = {}\n",
456
+ "\n",
457
+ "print(f'Running LIME on {n_lime} samples...')\n",
458
+ "for i, idx in enumerate(lime_idx):\n",
459
+ " exp = lime_explainer.explain_instance(X_test[idx], predict_fn, num_features=10, top_labels=1)\n",
460
+ " pred_class = np.argmax(predict_fn(X_test[idx].reshape(1, -1)))\n",
461
+ " for fw in exp.as_list(label=pred_class):\n",
462
+ " fname = fw[0].split(' ')[0]\n",
463
+ " all_top_features[fname] = all_top_features.get(fname, 0) + 1\n",
464
+ " if (i+1) % 10 == 0:\n",
465
+ " print(f' {i+1}/{n_lime} done')\n",
466
+ "\n",
467
+ "lime_sorted = sorted(all_top_features.items(), key=lambda x: x[1], reverse=True)\n",
468
+ "print(f'\\nTop 10 features by LIME frequency:')\n",
469
+ "for f, c in lime_sorted[:10]:\n",
470
+ " print(f' {f:35s}: {c}/{n_lime} explanations')"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "# LIME vs SHAP comparison\n",
480
+ "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
481
+ "\n",
482
+ "# SHAP\n",
483
+ "top10_shap = feature_importance[:10]\n",
484
+ "axes[0].barh(range(10), [v for _, v in top10_shap][::-1], color='steelblue')\n",
485
+ "axes[0].set_yticks(range(10)); axes[0].set_yticklabels([f for f, _ in top10_shap][::-1])\n",
486
+ "axes[0].set_xlabel('Mean |SHAP value|'); axes[0].set_title('SHAP Top 10')\n",
487
+ "\n",
488
+ "# LIME\n",
489
+ "top10_lime = lime_sorted[:10]\n",
490
+ "axes[1].barh(range(10), [v for _, v in top10_lime][::-1], color='coral')\n",
491
+ "axes[1].set_yticks(range(10)); axes[1].set_yticklabels([f for f, _ in top10_lime][::-1])\n",
492
+ "axes[1].set_xlabel(f'Frequency in top-10 (out of {n_lime})'); axes[1].set_title('LIME Top 10')\n",
493
+ "\n",
494
+ "plt.suptitle('SHAP vs LIME Feature Rankings', fontsize=14)\n",
495
+ "plt.tight_layout(); plt.show()\n",
496
+ "\n",
497
+ "# Rank correlation\n",
498
+ "shap_ranks = {f: i for i, (f, _) in enumerate(feature_importance[:20])}\n",
499
+ "lime_ranks = {f: i for i, (f, _) in enumerate(lime_sorted[:20])}\n",
500
+ "common = set(shap_ranks.keys()) & set(lime_ranks.keys())\n",
501
+ "if len(common) >= 5:\n",
502
+ " rho, p = spearmanr([shap_ranks[f] for f in common], [lime_ranks[f] for f in common])\n",
503
+ " print(f'\\nSHAP vs LIME Spearman correlation: {rho:.4f} (p={p:.4f}) over {len(common)} common features')"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "metadata": {},
509
+ "source": [
510
+ "## 6. Explanation Stability Evaluation"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "metadata": {},
517
+ "outputs": [],
518
+ "source": [
519
+ "def compute_shap_stability(explainer, sample, epsilon, n_perturbs=10):\n",
520
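+ " \"\"\"Compute SENS_MAX (largest L2 change in SHAP values) and the mean PCC (Pearson correlation with the unperturbed SHAP values) for one sample, using uniform noise in [-epsilon, epsilon] clipped to [0, 1].\"\"\"\n",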
521
+ " rng = np.random.RandomState(SEED)\n",
522
+ " base = np.array(explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True))\n",
523
+ " base = base[0].flatten() if isinstance(base, list) else base.flatten()\n",
524
+ " \n",
525
+ " max_delta, pccs = 0, []\n",
526
+ " for _ in range(n_perturbs):\n",
527
+ " noise = rng.uniform(-epsilon, epsilon, sample.shape)\n",
528
+ " perturbed = np.clip(sample + noise, 0, 1)\n",
529
+ " p_shap = np.array(explainer.shap_values(perturbed.reshape(1,-1), nsamples=100, silent=True))\n",
530
+ " p_shap = p_shap[0].flatten() if isinstance(p_shap, list) else p_shap.flatten()\n",
531
+ " max_delta = max(max_delta, np.linalg.norm(p_shap - base))\n",
532
+ " if np.std(base) > 1e-8 and np.std(p_shap) > 1e-8:\n",
533
+ " pccs.append(pearsonr(base, p_shap)[0])\n",
534
+ " return max_delta, np.mean(pccs) if pccs else 0.0\n",
535
+ "\n",
536
+ "# Test across epsilon values\n",
537
+ "epsilons = [0.01, 0.03, 0.05]\n",
538
+ "n_stability = 8\n",
539
+ "stability_idx = np.random.choice(len(X_test), n_stability, replace=False)\n",
540
+ "stability_results = {}\n",
541
+ "\n",
542
+ "for eps in epsilons:\n",
543
+ " sens_list, pcc_list = [], []\n",
544
+ " print(f'\\n--- SHAP Stability (eps={eps}) ---')\n",
545
+ " for i, idx in enumerate(stability_idx):\n",
546
+ " sm, pc = compute_shap_stability(explainer, X_test[idx], eps, n_perturbs=8)\n",
547
+ " sens_list.append(sm); pcc_list.append(pc)\n",
548
+ " if (i+1) % 4 == 0:\n",
549
+ " print(f' {i+1}/{n_stability} | SENS_MAX={sm:.4f} | PCC={pc:.4f}')\n",
550
+ " \n",
551
+ " stability_results[eps] = {'sens_max': np.mean(sens_list), 'pcc': np.mean(pcc_list)}\n",
552
+ " status = 'STABLE' if np.mean(pcc_list) > 0.6 else 'UNSTABLE'\n",
553
+ " print(f' Mean SENS_MAX={np.mean(sens_list):.4f} | Mean PCC={np.mean(pcc_list):.4f} [{status}]')"
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "execution_count": null,
559
+ "metadata": {},
560
+ "outputs": [],
561
+ "source": [
562
+ "# LIME stochastic stability\n",
563
+ "print('--- LIME Stochastic Stability ---')\n",
564
+ "lime_corrs = []\n",
565
+ "for i, idx in enumerate(stability_idx[:6]):\n",
566
+ " weight_vecs = []\n",
567
+ " for seed in range(10):\n",
568
+ " le_obj = lime_tabular.LimeTabularExplainer(X_train, feature_names=FEATURE_NAMES, discretize_continuous=True, random_state=seed)\n",
569
+ " exp = le_obj.explain_instance(X_test[idx], predict_fn, num_features=len(FEATURE_NAMES))\n",
570
+ " w = np.zeros(len(FEATURE_NAMES))\n",
571
+ " for key, val in dict(exp.as_list()).items():\n",
572
+ " for j, fn in enumerate(FEATURE_NAMES):\n",
573
+ " if fn in key: w[j] = val; break\n",
574
+ " weight_vecs.append(w)\n",
575
+ " corrs = []\n",
576
+ " for a in range(10):\n",
577
+ " for b in range(a+1, 10):\n",
578
+ " if np.std(weight_vecs[a]) > 1e-8 and np.std(weight_vecs[b]) > 1e-8:\n",
579
+ " corrs.append(spearmanr(weight_vecs[a], weight_vecs[b])[0])\n",
580
+ " mc = np.mean(corrs) if corrs else 0\n",
581
+ " lime_corrs.append(mc)\n",
582
+ " print(f' Sample {i+1}/6 | Mean Spearman: {mc:.4f}')\n",
583
+ "\n",
584
+ "lime_status = 'STABLE' if np.mean(lime_corrs) > 0.6 else 'UNSTABLE'\n",
585
+ "print(f'\\nOverall LIME stability: {np.mean(lime_corrs):.4f} [{lime_status}]')"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": null,
591
+ "metadata": {},
592
+ "outputs": [],
593
+ "source": [
594
+ "# Faithfulness evaluation\n",
595
+ "print('--- Faithfulness (Feature Masking) ---')\n",
596
+ "faith_results = {k: [] for k in [3, 5, 10]}\n",
597
+ "\n",
598
+ "for idx in stability_idx[:10]:\n",
599
+ " sample = X_test[idx]\n",
600
+ " sv = np.array(explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True))\n",
601
+ " sv = sv[0].flatten() if isinstance(sv, list) else sv.flatten()\n",
602
+ " \n",
603
+ " base_conf = predict_fn(sample.reshape(1,-1))[0]\n",
604
+ " pred_cls = np.argmax(base_conf)\n",
605
+ " \n",
606
+ " for k in faith_results:\n",
607
+ " masked = sample.copy()\n",
608
+ " masked[np.argsort(np.abs(sv))[-k:]] = 0.0\n",
609
+ " drop = base_conf[pred_cls] - predict_fn(masked.reshape(1,-1))[0][pred_cls]\n",
610
+ " faith_results[k].append(float(drop))\n",
611
+ "\n",
612
+ "for k, scores in faith_results.items():\n",
613
+ " print(f' Top-{k} masking: confidence drop = {np.mean(scores):.4f} +/- {np.std(scores):.4f}')"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": null,
619
+ "metadata": {},
620
+ "outputs": [],
621
+ "source": [
622
+ "# Stability summary plot\n",
623
+ "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
624
+ "\n",
625
+ "# SENS_MAX\n",
626
+ "eps_list = list(stability_results.keys())\n",
627
+ "axes[0].plot(eps_list, [stability_results[e]['sens_max'] for e in eps_list], 'o-', color='steelblue', markersize=8)\n",
628
+ "axes[0].set_xlabel('Perturbation epsilon'); axes[0].set_ylabel('SENS_MAX')\n",
629
+ "axes[0].set_title('SHAP Sensitivity (lower = more stable)'); axes[0].grid(alpha=0.3)\n",
630
+ "\n",
631
+ "# PCC\n",
632
+ "pcc_vals = [stability_results[e]['pcc'] for e in eps_list]\n",
633
+ "colors = ['green' if p > 0.6 else 'red' for p in pcc_vals]\n",
634
+ "axes[1].bar(range(len(eps_list)), pcc_vals, color=colors)\n",
635
+ "axes[1].set_xticks(range(len(eps_list))); axes[1].set_xticklabels([f'eps={e}' for e in eps_list])\n",
636
+ "axes[1].axhline(y=0.6, color='gray', linestyle='--', label='Threshold (0.6)')\n",
637
+ "axes[1].set_ylabel('Mean PCC'); axes[1].set_title('SHAP Stability'); axes[1].legend()\n",
638
+ "\n",
639
+ "# Faithfulness\n",
640
+ "ks = list(faith_results.keys())\n",
641
+ "axes[2].bar(range(len(ks)), [np.mean(faith_results[k]) for k in ks],\n",
642
+ " yerr=[np.std(faith_results[k]) for k in ks], color='coral', capsize=5)\n",
643
+ "axes[2].set_xticks(range(len(ks))); axes[2].set_xticklabels([f'Top-{k}' for k in ks])\n",
644
+ "axes[2].set_ylabel('Confidence drop'); axes[2].set_title('Faithfulness (higher = better)')\n",
645
+ "\n",
646
+ "plt.suptitle('Explanation Stability Evaluation (SAFARI Framework)', fontsize=14)\n",
647
+ "plt.tight_layout(); plt.show()"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "markdown",
652
+ "metadata": {},
653
+ "source": [
654
+ "## 7. Security Implications Summary"
655
+ ]
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": null,
660
+ "metadata": {},
661
+ "outputs": [],
662
+ "source": [
663
+ "# Analyze which top SHAP features are attacker-manipulable\n",
664
+ "manipulable = {'src_bytes', 'dst_bytes', 'hot', 'num_failed_logins', 'duration', 'num_compromised',\n",
665
+ " 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files'}\n",
666
+ "partial = {'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
667
+ " 'protocol_type', 'flag', 'service'}\n",
668
+ "non_manip = {'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',\n",
669
+ " 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',\n",
670
+ " 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',\n",
671
+ " 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate'}\n",
672
+ "\n",
673
+ "print('SECURITY ANALYSIS: Top 15 Features by Manipulability')\n",
674
+ "print('='*70)\n",
675
+ "manip_count = {'Manipulable': 0, 'Partial': 0, 'Non-manipulable': 0}\n",
676
+ "for i, (f, v) in enumerate(feature_importance[:15]):\n",
677
+ " if f in manipulable:\n",
678
+ " status = 'MANIPULABLE'\n",
679
+ " manip_count['Manipulable'] += 1\n",
680
+ " elif f in partial:\n",
681
+ " status = 'PARTIAL'\n",
682
+ " manip_count['Partial'] += 1\n",
683
+ " else:\n",
684
+ " status = 'NON-MANIPULABLE'\n",
685
+ " manip_count['Non-manipulable'] += 1\n",
686
+ " print(f' {i+1:2d}. {f:35s} SHAP={v:.4f} [{status}]')\n",
687
+ "\n",
688
+ "print(f'\\nSummary: {manip_count}')\n",
689
+ "if manip_count['Non-manipulable'] > manip_count['Manipulable']:\n",
690
+ " print('-> Model relies more on non-manipulable features -> MORE ROBUST against evasion')\n",
691
+ "else:\n",
692
+ " print('-> Model relies more on manipulable features -> LESS ROBUST against evasion')"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "metadata": {},
699
+ "outputs": [],
700
+ "source": [
701
+ "# Final summary\n",
702
+ "print('\\n' + '='*60)\n",
703
+ "print('FINAL RESULTS SUMMARY')\n",
704
+ "print('='*60)\n",
705
+ "print(f'\\n1. MODEL COMPARISON:')\n",
706
+ "for name in ['mlp', 'lstm', 'cnn1d']:\n",
707
+ " r = results[name]\n",
708
+ " print(f' {name.upper():6s}: F1={r[\"f1\"]:.4f} | ROC-AUC={r[\"roc_auc\"]:.4f} | PR-AUC={r[\"pr_auc\"]:.4f}')\n",
709
+ "\n",
710
+ "print(f'\\n2. EXPLANATION STABILITY (SAFARI):')\n",
711
+ "for eps in epsilons:\n",
712
+ " sr = stability_results[eps]\n",
713
+ " status = 'STABLE' if sr['pcc'] > 0.6 else 'UNSTABLE'\n",
714
+ " print(f' eps={eps}: SENS_MAX={sr[\"sens_max\"]:.4f} | PCC={sr[\"pcc\"]:.4f} [{status}]')\n",
715
+ "print(f' LIME: Spearman={np.mean(lime_corrs):.4f} [{\"STABLE\" if np.mean(lime_corrs) > 0.6 else \"UNSTABLE\"}]')\n",
716
+ "\n",
717
+ "print(f'\\n3. FAITHFULNESS:')\n",
718
+ "for k in [3, 5, 10]:\n",
719
+ " print(f' Top-{k}: confidence drop = {np.mean(faith_results[k]):.4f}')\n",
720
+ "\n",
721
+ "print(f'\\n4. SECURITY: Top features manipulability = {manip_count}')\n",
722
+ "print('\\nDone!')"
723
+ ]
724
+ }
725
+ ],
726
+ "metadata": {
727
+ "kernelspec": {
728
+ "display_name": "Python 3",
729
+ "language": "python",
730
+ "name": "python3"
731
+ },
732
+ "language_info": {
733
+ "name": "python",
734
+ "version": "3.10.0"
735
+ },
736
+ "accelerator": "GPU",
737
+ "colab": {
738
+ "gpuType": "T4"
739
+ }
740
+ },
741
+ "nbformat": 4,
742
+ "nbformat_minor": 4
743
+ }