Fix SHAP compatibility — handle both old list format and new 3D array format
Browse files
explainable_ids_full_pipeline.ipynb
CHANGED
|
@@ -59,7 +59,16 @@
|
|
| 59 |
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 60 |
"print(f'Device: {DEVICE}')\n",
|
| 61 |
"if DEVICE.type == 'cuda':\n",
|
| 62 |
-
" print(f'GPU: {torch.cuda.get_device_name(0)}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
]
|
| 64 |
},
|
| 65 |
{
|
|
@@ -252,7 +261,6 @@
|
|
| 252 |
" optimizer.step()\n",
|
| 253 |
" total_loss += loss.item() * len(yb)\n",
|
| 254 |
" \n",
|
| 255 |
-
" # Evaluate\n",
|
| 256 |
" model.eval()\n",
|
| 257 |
" preds, probs, labels = [], [], []\n",
|
| 258 |
" with torch.no_grad():\n",
|
|
@@ -284,7 +292,6 @@
|
|
| 284 |
" \n",
|
| 285 |
" dt = time.time() - t0\n",
|
| 286 |
" \n",
|
| 287 |
-
" # Load best and final eval\n",
|
| 288 |
" model.load_state_dict(best_state)\n",
|
| 289 |
" model.eval()\n",
|
| 290 |
" preds, probs, labels = [], [], []\n",
|
|
@@ -307,9 +314,8 @@
|
|
| 307 |
" print('Confusion Matrix:')\n",
|
| 308 |
" print(confusion_matrix(labels, preds))\n",
|
| 309 |
" \n",
|
| 310 |
-
" return model, {'f1': best_f1, 'roc_auc': roc, 'pr_auc': pr, 'time': dt, 'history': history
|
| 311 |
"\n",
|
| 312 |
-
"# Train all 3\n",
|
| 313 |
"models = {}\n",
|
| 314 |
"results = {}\n",
|
| 315 |
"for name, cls in [('mlp', MLP_IDS), ('lstm', LSTM_IDS), ('cnn1d', CNN1D_IDS)]:\n",
|
|
@@ -373,8 +379,11 @@
|
|
| 373 |
"\n",
|
| 374 |
"explainer = shap.KernelExplainer(predict_fn, X_train[bg_idx])\n",
|
| 375 |
"print('Computing SHAP values for 150 test samples (this takes a few minutes)...')\n",
|
| 376 |
-
"
|
| 377 |
-
"
|
|
|
|
|
|
|
|
|
|
| 378 |
]
|
| 379 |
},
|
| 380 |
{
|
|
@@ -384,7 +393,7 @@
|
|
| 384 |
"outputs": [],
|
| 385 |
"source": [
|
| 386 |
"# Global feature importance (anomaly class)\n",
|
| 387 |
-
"mean_abs_shap = np.abs(
|
| 388 |
"feature_importance = sorted(zip(FEATURE_NAMES, mean_abs_shap), key=lambda x: x[1], reverse=True)\n",
|
| 389 |
"\n",
|
| 390 |
"print('Top 15 features by mean |SHAP| (anomaly class):')\n",
|
|
@@ -399,7 +408,7 @@
|
|
| 399 |
"outputs": [],
|
| 400 |
"source": [
|
| 401 |
"# SHAP summary plot\n",
|
| 402 |
-
"shap.summary_plot(
|
| 403 |
]
|
| 404 |
},
|
| 405 |
{
|
|
@@ -429,7 +438,10 @@
|
|
| 429 |
"pred = predict_fn(X_test[exp_idx[idx:idx+1]])\n",
|
| 430 |
"print(f'Sample prediction: anomaly={pred[0][0]:.3f}, normal={pred[0][1]:.3f}')\n",
|
| 431 |
"print(f'True label: {class_names[y_test[exp_idx[idx]]]}')\n",
|
| 432 |
-
"
|
|
|
|
|
|
|
|
|
|
| 433 |
]
|
| 434 |
},
|
| 435 |
{
|
|
@@ -479,13 +491,11 @@
|
|
| 479 |
"# LIME vs SHAP comparison\n",
|
| 480 |
"fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
|
| 481 |
"\n",
|
| 482 |
-
"# SHAP\n",
|
| 483 |
"top10_shap = feature_importance[:10]\n",
|
| 484 |
"axes[0].barh(range(10), [v for _, v in top10_shap][::-1], color='steelblue')\n",
|
| 485 |
"axes[0].set_yticks(range(10)); axes[0].set_yticklabels([f for f, _ in top10_shap][::-1])\n",
|
| 486 |
"axes[0].set_xlabel('Mean |SHAP value|'); axes[0].set_title('SHAP Top 10')\n",
|
| 487 |
"\n",
|
| 488 |
-
"# LIME\n",
|
| 489 |
"top10_lime = lime_sorted[:10]\n",
|
| 490 |
"axes[1].barh(range(10), [v for _, v in top10_lime][::-1], color='coral')\n",
|
| 491 |
"axes[1].set_yticks(range(10)); axes[1].set_yticklabels([f for f, _ in top10_lime][::-1])\n",
|
|
@@ -519,15 +529,15 @@
|
|
| 519 |
"def compute_shap_stability(explainer, sample, epsilon, n_perturbs=10):\n",
|
| 520 |
" \"\"\"Compute SENS_MAX and PCC for one sample.\"\"\"\n",
|
| 521 |
" rng = np.random.RandomState(SEED)\n",
|
| 522 |
-
"
|
| 523 |
-
" base =
|
| 524 |
" \n",
|
| 525 |
" max_delta, pccs = 0, []\n",
|
| 526 |
" for _ in range(n_perturbs):\n",
|
| 527 |
" noise = rng.uniform(-epsilon, epsilon, sample.shape)\n",
|
| 528 |
" perturbed = np.clip(sample + noise, 0, 1)\n",
|
| 529 |
-
"
|
| 530 |
-
" p_shap =
|
| 531 |
" max_delta = max(max_delta, np.linalg.norm(p_shap - base))\n",
|
| 532 |
" if np.std(base) > 1e-8 and np.std(p_shap) > 1e-8:\n",
|
| 533 |
" pccs.append(pearsonr(base, p_shap)[0])\n",
|
|
@@ -597,8 +607,8 @@
|
|
| 597 |
"\n",
|
| 598 |
"for idx in stability_idx[:10]:\n",
|
| 599 |
" sample = X_test[idx]\n",
|
| 600 |
-
"
|
| 601 |
-
" sv =
|
| 602 |
" \n",
|
| 603 |
" base_conf = predict_fn(sample.reshape(1,-1))[0]\n",
|
| 604 |
" pred_cls = np.argmax(base_conf)\n",
|
|
@@ -622,13 +632,11 @@
|
|
| 622 |
"# Stability summary plot\n",
|
| 623 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
|
| 624 |
"\n",
|
| 625 |
-
"# SENS_MAX\n",
|
| 626 |
"eps_list = list(stability_results.keys())\n",
|
| 627 |
"axes[0].plot(eps_list, [stability_results[e]['sens_max'] for e in eps_list], 'o-', color='steelblue', markersize=8)\n",
|
| 628 |
"axes[0].set_xlabel('Perturbation epsilon'); axes[0].set_ylabel('SENS_MAX')\n",
|
| 629 |
"axes[0].set_title('SHAP Sensitivity (lower = more stable)'); axes[0].grid(alpha=0.3)\n",
|
| 630 |
"\n",
|
| 631 |
-
"# PCC\n",
|
| 632 |
"pcc_vals = [stability_results[e]['pcc'] for e in eps_list]\n",
|
| 633 |
"colors = ['green' if p > 0.6 else 'red' for p in pcc_vals]\n",
|
| 634 |
"axes[1].bar(range(len(eps_list)), pcc_vals, color=colors)\n",
|
|
@@ -636,7 +644,6 @@
|
|
| 636 |
"axes[1].axhline(y=0.6, color='gray', linestyle='--', label='Threshold (0.6)')\n",
|
| 637 |
"axes[1].set_ylabel('Mean PCC'); axes[1].set_title('SHAP Stability'); axes[1].legend()\n",
|
| 638 |
"\n",
|
| 639 |
-
"# Faithfulness\n",
|
| 640 |
"ks = list(faith_results.keys())\n",
|
| 641 |
"axes[2].bar(range(len(ks)), [np.mean(faith_results[k]) for k in ks],\n",
|
| 642 |
" yerr=[np.std(faith_results[k]) for k in ks], color='coral', capsize=5)\n",
|
|
@@ -660,15 +667,10 @@
|
|
| 660 |
"metadata": {},
|
| 661 |
"outputs": [],
|
| 662 |
"source": [
|
| 663 |
-
"# Analyze which top SHAP features are attacker-manipulable\n",
|
| 664 |
"manipulable = {'src_bytes', 'dst_bytes', 'hot', 'num_failed_logins', 'duration', 'num_compromised',\n",
|
| 665 |
" 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files'}\n",
|
| 666 |
"partial = {'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
|
| 667 |
" 'protocol_type', 'flag', 'service'}\n",
|
| 668 |
-
"non_manip = {'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',\n",
|
| 669 |
-
" 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',\n",
|
| 670 |
-
" 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',\n",
|
| 671 |
-
" 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate'}\n",
|
| 672 |
"\n",
|
| 673 |
"print('SECURITY ANALYSIS: Top 15 Features by Manipulability')\n",
|
| 674 |
"print('='*70)\n",
|
|
@@ -698,7 +700,6 @@
|
|
| 698 |
"metadata": {},
|
| 699 |
"outputs": [],
|
| 700 |
"source": [
|
| 701 |
-
"# Final summary\n",
|
| 702 |
"print('\\n' + '='*60)\n",
|
| 703 |
"print('FINAL RESULTS SUMMARY')\n",
|
| 704 |
"print('='*60)\n",
|
|
|
|
| 59 |
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 60 |
"print(f'Device: {DEVICE}')\n",
|
| 61 |
"if DEVICE.type == 'cuda':\n",
|
| 62 |
+
" print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"def get_shap_for_class(shap_values, class_idx=0):\n",
|
| 65 |
+
" \"\"\"Handle both old SHAP format (list of arrays) and new format (3D array).\"\"\"\n",
|
| 66 |
+
" if isinstance(shap_values, list):\n",
|
| 67 |
+
" return shap_values[class_idx]\n",
|
| 68 |
+
" elif shap_values.ndim == 3:\n",
|
| 69 |
+
" return shap_values[:, :, class_idx]\n",
|
| 70 |
+
" else:\n",
|
| 71 |
+
" return shap_values"
|
| 72 |
]
|
| 73 |
},
|
| 74 |
{
|
|
|
|
| 261 |
" optimizer.step()\n",
|
| 262 |
" total_loss += loss.item() * len(yb)\n",
|
| 263 |
" \n",
|
|
|
|
| 264 |
" model.eval()\n",
|
| 265 |
" preds, probs, labels = [], [], []\n",
|
| 266 |
" with torch.no_grad():\n",
|
|
|
|
| 292 |
" \n",
|
| 293 |
" dt = time.time() - t0\n",
|
| 294 |
" \n",
|
|
|
|
| 295 |
" model.load_state_dict(best_state)\n",
|
| 296 |
" model.eval()\n",
|
| 297 |
" preds, probs, labels = [], [], []\n",
|
|
|
|
| 314 |
" print('Confusion Matrix:')\n",
|
| 315 |
" print(confusion_matrix(labels, preds))\n",
|
| 316 |
" \n",
|
| 317 |
+
" return model, {'f1': best_f1, 'roc_auc': roc, 'pr_auc': pr, 'time': dt, 'history': history}\n",
|
| 318 |
"\n",
|
|
|
|
| 319 |
"models = {}\n",
|
| 320 |
"results = {}\n",
|
| 321 |
"for name, cls in [('mlp', MLP_IDS), ('lstm', LSTM_IDS), ('cnn1d', CNN1D_IDS)]:\n",
|
|
|
|
| 379 |
"\n",
|
| 380 |
"explainer = shap.KernelExplainer(predict_fn, X_train[bg_idx])\n",
|
| 381 |
"print('Computing SHAP values for 150 test samples (this takes a few minutes)...')\n",
|
| 382 |
+
"shap_values_raw = explainer.shap_values(X_test[exp_idx], nsamples=200, silent=True)\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"# Get SHAP values for anomaly class (class 0) — works with any SHAP version\n",
|
| 385 |
+
"shap_vals_anomaly = get_shap_for_class(shap_values_raw, class_idx=0)\n",
|
| 386 |
+
"print(f'Done! SHAP values shape: {shap_vals_anomaly.shape}')"
|
| 387 |
]
|
| 388 |
},
|
| 389 |
{
|
|
|
|
| 393 |
"outputs": [],
|
| 394 |
"source": [
|
| 395 |
"# Global feature importance (anomaly class)\n",
|
| 396 |
+
"mean_abs_shap = np.abs(shap_vals_anomaly).mean(axis=0)\n",
|
| 397 |
"feature_importance = sorted(zip(FEATURE_NAMES, mean_abs_shap), key=lambda x: x[1], reverse=True)\n",
|
| 398 |
"\n",
|
| 399 |
"print('Top 15 features by mean |SHAP| (anomaly class):')\n",
|
|
|
|
| 408 |
"outputs": [],
|
| 409 |
"source": [
|
| 410 |
"# SHAP summary plot\n",
|
| 411 |
+
"shap.summary_plot(shap_vals_anomaly, X_test[exp_idx], feature_names=FEATURE_NAMES, max_display=15)"
|
| 412 |
]
|
| 413 |
},
|
| 414 |
{
|
|
|
|
| 438 |
"pred = predict_fn(X_test[exp_idx[idx:idx+1]])\n",
|
| 439 |
"print(f'Sample prediction: anomaly={pred[0][0]:.3f}, normal={pred[0][1]:.3f}')\n",
|
| 440 |
"print(f'True label: {class_names[y_test[exp_idx[idx]]]}')\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"ev = explainer.expected_value\n",
|
| 443 |
+
"ev0 = ev[0] if isinstance(ev, (list, np.ndarray)) else ev\n",
|
| 444 |
+
"shap.force_plot(ev0, shap_vals_anomaly[idx], X_test[exp_idx[idx]], feature_names=FEATURE_NAMES, matplotlib=True)"
|
| 445 |
]
|
| 446 |
},
|
| 447 |
{
|
|
|
|
| 491 |
"# LIME vs SHAP comparison\n",
|
| 492 |
"fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
|
| 493 |
"\n",
|
|
|
|
| 494 |
"top10_shap = feature_importance[:10]\n",
|
| 495 |
"axes[0].barh(range(10), [v for _, v in top10_shap][::-1], color='steelblue')\n",
|
| 496 |
"axes[0].set_yticks(range(10)); axes[0].set_yticklabels([f for f, _ in top10_shap][::-1])\n",
|
| 497 |
"axes[0].set_xlabel('Mean |SHAP value|'); axes[0].set_title('SHAP Top 10')\n",
|
| 498 |
"\n",
|
|
|
|
| 499 |
"top10_lime = lime_sorted[:10]\n",
|
| 500 |
"axes[1].barh(range(10), [v for _, v in top10_lime][::-1], color='coral')\n",
|
| 501 |
"axes[1].set_yticks(range(10)); axes[1].set_yticklabels([f for f, _ in top10_lime][::-1])\n",
|
|
|
|
| 529 |
"def compute_shap_stability(explainer, sample, epsilon, n_perturbs=10):\n",
|
| 530 |
" \"\"\"Compute SENS_MAX and PCC for one sample.\"\"\"\n",
|
| 531 |
" rng = np.random.RandomState(SEED)\n",
|
| 532 |
+
" base_raw = explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True)\n",
|
| 533 |
+
" base = get_shap_for_class(base_raw, 0).flatten()\n",
|
| 534 |
" \n",
|
| 535 |
" max_delta, pccs = 0, []\n",
|
| 536 |
" for _ in range(n_perturbs):\n",
|
| 537 |
" noise = rng.uniform(-epsilon, epsilon, sample.shape)\n",
|
| 538 |
" perturbed = np.clip(sample + noise, 0, 1)\n",
|
| 539 |
+
" p_raw = explainer.shap_values(perturbed.reshape(1,-1), nsamples=100, silent=True)\n",
|
| 540 |
+
" p_shap = get_shap_for_class(p_raw, 0).flatten()\n",
|
| 541 |
" max_delta = max(max_delta, np.linalg.norm(p_shap - base))\n",
|
| 542 |
" if np.std(base) > 1e-8 and np.std(p_shap) > 1e-8:\n",
|
| 543 |
" pccs.append(pearsonr(base, p_shap)[0])\n",
|
|
|
|
| 607 |
"\n",
|
| 608 |
"for idx in stability_idx[:10]:\n",
|
| 609 |
" sample = X_test[idx]\n",
|
| 610 |
+
" sv_raw = explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True)\n",
|
| 611 |
+
" sv = get_shap_for_class(sv_raw, 0).flatten()\n",
|
| 612 |
" \n",
|
| 613 |
" base_conf = predict_fn(sample.reshape(1,-1))[0]\n",
|
| 614 |
" pred_cls = np.argmax(base_conf)\n",
|
|
|
|
| 632 |
"# Stability summary plot\n",
|
| 633 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
|
| 634 |
"\n",
|
|
|
|
| 635 |
"eps_list = list(stability_results.keys())\n",
|
| 636 |
"axes[0].plot(eps_list, [stability_results[e]['sens_max'] for e in eps_list], 'o-', color='steelblue', markersize=8)\n",
|
| 637 |
"axes[0].set_xlabel('Perturbation epsilon'); axes[0].set_ylabel('SENS_MAX')\n",
|
| 638 |
"axes[0].set_title('SHAP Sensitivity (lower = more stable)'); axes[0].grid(alpha=0.3)\n",
|
| 639 |
"\n",
|
|
|
|
| 640 |
"pcc_vals = [stability_results[e]['pcc'] for e in eps_list]\n",
|
| 641 |
"colors = ['green' if p > 0.6 else 'red' for p in pcc_vals]\n",
|
| 642 |
"axes[1].bar(range(len(eps_list)), pcc_vals, color=colors)\n",
|
|
|
|
| 644 |
"axes[1].axhline(y=0.6, color='gray', linestyle='--', label='Threshold (0.6)')\n",
|
| 645 |
"axes[1].set_ylabel('Mean PCC'); axes[1].set_title('SHAP Stability'); axes[1].legend()\n",
|
| 646 |
"\n",
|
|
|
|
| 647 |
"ks = list(faith_results.keys())\n",
|
| 648 |
"axes[2].bar(range(len(ks)), [np.mean(faith_results[k]) for k in ks],\n",
|
| 649 |
" yerr=[np.std(faith_results[k]) for k in ks], color='coral', capsize=5)\n",
|
|
|
|
| 667 |
"metadata": {},
|
| 668 |
"outputs": [],
|
| 669 |
"source": [
|
|
|
|
| 670 |
"manipulable = {'src_bytes', 'dst_bytes', 'hot', 'num_failed_logins', 'duration', 'num_compromised',\n",
|
| 671 |
" 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files'}\n",
|
| 672 |
"partial = {'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
|
| 673 |
" 'protocol_type', 'flag', 'service'}\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
"\n",
|
| 675 |
"print('SECURITY ANALYSIS: Top 15 Features by Manipulability')\n",
|
| 676 |
"print('='*70)\n",
|
|
|
|
| 700 |
"metadata": {},
|
| 701 |
"outputs": [],
|
| 702 |
"source": [
|
|
|
|
| 703 |
"print('\\n' + '='*60)\n",
|
| 704 |
"print('FINAL RESULTS SUMMARY')\n",
|
| 705 |
"print('='*60)\n",
|