cathrica committed on
Commit
e2e801c
·
verified ·
1 Parent(s): f5ee9d3

Fix SHAP compatibility — handle both old list format and new 3D array format

Browse files
Files changed (1) hide show
  1. explainable_ids_full_pipeline.ipynb +28 -27
explainable_ids_full_pipeline.ipynb CHANGED
@@ -59,7 +59,16 @@
59
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
60
  "print(f'Device: {DEVICE}')\n",
61
  "if DEVICE.type == 'cuda':\n",
62
- " print(f'GPU: {torch.cuda.get_device_name(0)}')"
 
 
 
 
 
 
 
 
 
63
  ]
64
  },
65
  {
@@ -252,7 +261,6 @@
252
  " optimizer.step()\n",
253
  " total_loss += loss.item() * len(yb)\n",
254
  " \n",
255
- " # Evaluate\n",
256
  " model.eval()\n",
257
  " preds, probs, labels = [], [], []\n",
258
  " with torch.no_grad():\n",
@@ -284,7 +292,6 @@
284
  " \n",
285
  " dt = time.time() - t0\n",
286
  " \n",
287
- " # Load best and final eval\n",
288
  " model.load_state_dict(best_state)\n",
289
  " model.eval()\n",
290
  " preds, probs, labels = [], [], []\n",
@@ -307,9 +314,8 @@
307
  " print('Confusion Matrix:')\n",
308
  " print(confusion_matrix(labels, preds))\n",
309
  " \n",
310
- " return model, {'f1': best_f1, 'roc_auc': roc, 'pr_auc': pr, 'time': dt, 'history': history, 'preds': preds, 'probs': probs, 'labels': labels}\n",
311
  "\n",
312
- "# Train all 3\n",
313
  "models = {}\n",
314
  "results = {}\n",
315
  "for name, cls in [('mlp', MLP_IDS), ('lstm', LSTM_IDS), ('cnn1d', CNN1D_IDS)]:\n",
@@ -373,8 +379,11 @@
373
  "\n",
374
  "explainer = shap.KernelExplainer(predict_fn, X_train[bg_idx])\n",
375
  "print('Computing SHAP values for 150 test samples (this takes a few minutes)...')\n",
376
- "shap_values = explainer.shap_values(X_test[exp_idx], nsamples=200, silent=True)\n",
377
- "print('Done!')"
 
 
 
378
  ]
379
  },
380
  {
@@ -384,7 +393,7 @@
384
  "outputs": [],
385
  "source": [
386
  "# Global feature importance (anomaly class)\n",
387
- "mean_abs_shap = np.abs(shap_values[0]).mean(axis=0)\n",
388
  "feature_importance = sorted(zip(FEATURE_NAMES, mean_abs_shap), key=lambda x: x[1], reverse=True)\n",
389
  "\n",
390
  "print('Top 15 features by mean |SHAP| (anomaly class):')\n",
@@ -399,7 +408,7 @@
399
  "outputs": [],
400
  "source": [
401
  "# SHAP summary plot\n",
402
- "shap.summary_plot(shap_values[0], X_test[exp_idx], feature_names=FEATURE_NAMES, max_display=15)"
403
  ]
404
  },
405
  {
@@ -429,7 +438,10 @@
429
  "pred = predict_fn(X_test[exp_idx[idx:idx+1]])\n",
430
  "print(f'Sample prediction: anomaly={pred[0][0]:.3f}, normal={pred[0][1]:.3f}')\n",
431
  "print(f'True label: {class_names[y_test[exp_idx[idx]]]}')\n",
432
- "shap.force_plot(explainer.expected_value[0], shap_values[0][idx], X_test[exp_idx[idx]], feature_names=FEATURE_NAMES, matplotlib=True)"
 
 
 
433
  ]
434
  },
435
  {
@@ -479,13 +491,11 @@
479
  "# LIME vs SHAP comparison\n",
480
  "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
481
  "\n",
482
- "# SHAP\n",
483
  "top10_shap = feature_importance[:10]\n",
484
  "axes[0].barh(range(10), [v for _, v in top10_shap][::-1], color='steelblue')\n",
485
  "axes[0].set_yticks(range(10)); axes[0].set_yticklabels([f for f, _ in top10_shap][::-1])\n",
486
  "axes[0].set_xlabel('Mean |SHAP value|'); axes[0].set_title('SHAP Top 10')\n",
487
  "\n",
488
- "# LIME\n",
489
  "top10_lime = lime_sorted[:10]\n",
490
  "axes[1].barh(range(10), [v for _, v in top10_lime][::-1], color='coral')\n",
491
  "axes[1].set_yticks(range(10)); axes[1].set_yticklabels([f for f, _ in top10_lime][::-1])\n",
@@ -519,15 +529,15 @@
519
  "def compute_shap_stability(explainer, sample, epsilon, n_perturbs=10):\n",
520
  " \"\"\"Compute SENS_MAX and PCC for one sample.\"\"\"\n",
521
  " rng = np.random.RandomState(SEED)\n",
522
- " base = np.array(explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True))\n",
523
- " base = base[0].flatten() if isinstance(base, list) else base.flatten()\n",
524
  " \n",
525
  " max_delta, pccs = 0, []\n",
526
  " for _ in range(n_perturbs):\n",
527
  " noise = rng.uniform(-epsilon, epsilon, sample.shape)\n",
528
  " perturbed = np.clip(sample + noise, 0, 1)\n",
529
- " p_shap = np.array(explainer.shap_values(perturbed.reshape(1,-1), nsamples=100, silent=True))\n",
530
- " p_shap = p_shap[0].flatten() if isinstance(p_shap, list) else p_shap.flatten()\n",
531
  " max_delta = max(max_delta, np.linalg.norm(p_shap - base))\n",
532
  " if np.std(base) > 1e-8 and np.std(p_shap) > 1e-8:\n",
533
  " pccs.append(pearsonr(base, p_shap)[0])\n",
@@ -597,8 +607,8 @@
597
  "\n",
598
  "for idx in stability_idx[:10]:\n",
599
  " sample = X_test[idx]\n",
600
- " sv = np.array(explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True))\n",
601
- " sv = sv[0].flatten() if isinstance(sv, list) else sv.flatten()\n",
602
  " \n",
603
  " base_conf = predict_fn(sample.reshape(1,-1))[0]\n",
604
  " pred_cls = np.argmax(base_conf)\n",
@@ -622,13 +632,11 @@
622
  "# Stability summary plot\n",
623
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
624
  "\n",
625
- "# SENS_MAX\n",
626
  "eps_list = list(stability_results.keys())\n",
627
  "axes[0].plot(eps_list, [stability_results[e]['sens_max'] for e in eps_list], 'o-', color='steelblue', markersize=8)\n",
628
  "axes[0].set_xlabel('Perturbation epsilon'); axes[0].set_ylabel('SENS_MAX')\n",
629
  "axes[0].set_title('SHAP Sensitivity (lower = more stable)'); axes[0].grid(alpha=0.3)\n",
630
  "\n",
631
- "# PCC\n",
632
  "pcc_vals = [stability_results[e]['pcc'] for e in eps_list]\n",
633
  "colors = ['green' if p > 0.6 else 'red' for p in pcc_vals]\n",
634
  "axes[1].bar(range(len(eps_list)), pcc_vals, color=colors)\n",
@@ -636,7 +644,6 @@
636
  "axes[1].axhline(y=0.6, color='gray', linestyle='--', label='Threshold (0.6)')\n",
637
  "axes[1].set_ylabel('Mean PCC'); axes[1].set_title('SHAP Stability'); axes[1].legend()\n",
638
  "\n",
639
- "# Faithfulness\n",
640
  "ks = list(faith_results.keys())\n",
641
  "axes[2].bar(range(len(ks)), [np.mean(faith_results[k]) for k in ks],\n",
642
  " yerr=[np.std(faith_results[k]) for k in ks], color='coral', capsize=5)\n",
@@ -660,15 +667,10 @@
660
  "metadata": {},
661
  "outputs": [],
662
  "source": [
663
- "# Analyze which top SHAP features are attacker-manipulable\n",
664
  "manipulable = {'src_bytes', 'dst_bytes', 'hot', 'num_failed_logins', 'duration', 'num_compromised',\n",
665
  " 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files'}\n",
666
  "partial = {'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
667
  " 'protocol_type', 'flag', 'service'}\n",
668
- "non_manip = {'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',\n",
669
- " 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',\n",
670
- " 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',\n",
671
- " 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate'}\n",
672
  "\n",
673
  "print('SECURITY ANALYSIS: Top 15 Features by Manipulability')\n",
674
  "print('='*70)\n",
@@ -698,7 +700,6 @@
698
  "metadata": {},
699
  "outputs": [],
700
  "source": [
701
- "# Final summary\n",
702
  "print('\\n' + '='*60)\n",
703
  "print('FINAL RESULTS SUMMARY')\n",
704
  "print('='*60)\n",
 
59
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
60
  "print(f'Device: {DEVICE}')\n",
61
  "if DEVICE.type == 'cuda':\n",
62
+ " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
63
+ "\n",
64
+ "def get_shap_for_class(shap_values, class_idx=0):\n",
65
+ " \"\"\"Handle both old SHAP format (list of arrays) and new format (3D array).\"\"\"\n",
66
+ " if isinstance(shap_values, list):\n",
67
+ " return shap_values[class_idx]\n",
68
+ " elif shap_values.ndim == 3:\n",
69
+ " return shap_values[:, :, class_idx]\n",
70
+ " else:\n",
71
+ " return shap_values"
72
  ]
73
  },
74
  {
 
261
  " optimizer.step()\n",
262
  " total_loss += loss.item() * len(yb)\n",
263
  " \n",
 
264
  " model.eval()\n",
265
  " preds, probs, labels = [], [], []\n",
266
  " with torch.no_grad():\n",
 
292
  " \n",
293
  " dt = time.time() - t0\n",
294
  " \n",
 
295
  " model.load_state_dict(best_state)\n",
296
  " model.eval()\n",
297
  " preds, probs, labels = [], [], []\n",
 
314
  " print('Confusion Matrix:')\n",
315
  " print(confusion_matrix(labels, preds))\n",
316
  " \n",
317
+ " return model, {'f1': best_f1, 'roc_auc': roc, 'pr_auc': pr, 'time': dt, 'history': history}\n",
318
  "\n",
 
319
  "models = {}\n",
320
  "results = {}\n",
321
  "for name, cls in [('mlp', MLP_IDS), ('lstm', LSTM_IDS), ('cnn1d', CNN1D_IDS)]:\n",
 
379
  "\n",
380
  "explainer = shap.KernelExplainer(predict_fn, X_train[bg_idx])\n",
381
  "print('Computing SHAP values for 150 test samples (this takes a few minutes)...')\n",
382
+ "shap_values_raw = explainer.shap_values(X_test[exp_idx], nsamples=200, silent=True)\n",
383
+ "\n",
384
+ "# Get SHAP values for anomaly class (class 0) — works with any SHAP version\n",
385
+ "shap_vals_anomaly = get_shap_for_class(shap_values_raw, class_idx=0)\n",
386
+ "print(f'Done! SHAP values shape: {shap_vals_anomaly.shape}')"
387
  ]
388
  },
389
  {
 
393
  "outputs": [],
394
  "source": [
395
  "# Global feature importance (anomaly class)\n",
396
+ "mean_abs_shap = np.abs(shap_vals_anomaly).mean(axis=0)\n",
397
  "feature_importance = sorted(zip(FEATURE_NAMES, mean_abs_shap), key=lambda x: x[1], reverse=True)\n",
398
  "\n",
399
  "print('Top 15 features by mean |SHAP| (anomaly class):')\n",
 
408
  "outputs": [],
409
  "source": [
410
  "# SHAP summary plot\n",
411
+ "shap.summary_plot(shap_vals_anomaly, X_test[exp_idx], feature_names=FEATURE_NAMES, max_display=15)"
412
  ]
413
  },
414
  {
 
438
  "pred = predict_fn(X_test[exp_idx[idx:idx+1]])\n",
439
  "print(f'Sample prediction: anomaly={pred[0][0]:.3f}, normal={pred[0][1]:.3f}')\n",
440
  "print(f'True label: {class_names[y_test[exp_idx[idx]]]}')\n",
441
+ "\n",
442
+ "ev = explainer.expected_value\n",
443
+ "ev0 = ev[0] if isinstance(ev, (list, np.ndarray)) else ev\n",
444
+ "shap.force_plot(ev0, shap_vals_anomaly[idx], X_test[exp_idx[idx]], feature_names=FEATURE_NAMES, matplotlib=True)"
445
  ]
446
  },
447
  {
 
491
  "# LIME vs SHAP comparison\n",
492
  "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
493
  "\n",
 
494
  "top10_shap = feature_importance[:10]\n",
495
  "axes[0].barh(range(10), [v for _, v in top10_shap][::-1], color='steelblue')\n",
496
  "axes[0].set_yticks(range(10)); axes[0].set_yticklabels([f for f, _ in top10_shap][::-1])\n",
497
  "axes[0].set_xlabel('Mean |SHAP value|'); axes[0].set_title('SHAP Top 10')\n",
498
  "\n",
 
499
  "top10_lime = lime_sorted[:10]\n",
500
  "axes[1].barh(range(10), [v for _, v in top10_lime][::-1], color='coral')\n",
501
  "axes[1].set_yticks(range(10)); axes[1].set_yticklabels([f for f, _ in top10_lime][::-1])\n",
 
529
  "def compute_shap_stability(explainer, sample, epsilon, n_perturbs=10):\n",
530
  " \"\"\"Compute SENS_MAX and PCC for one sample.\"\"\"\n",
531
  " rng = np.random.RandomState(SEED)\n",
532
+ " base_raw = explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True)\n",
533
+ " base = get_shap_for_class(base_raw, 0).flatten()\n",
534
  " \n",
535
  " max_delta, pccs = 0, []\n",
536
  " for _ in range(n_perturbs):\n",
537
  " noise = rng.uniform(-epsilon, epsilon, sample.shape)\n",
538
  " perturbed = np.clip(sample + noise, 0, 1)\n",
539
+ " p_raw = explainer.shap_values(perturbed.reshape(1,-1), nsamples=100, silent=True)\n",
540
+ " p_shap = get_shap_for_class(p_raw, 0).flatten()\n",
541
  " max_delta = max(max_delta, np.linalg.norm(p_shap - base))\n",
542
  " if np.std(base) > 1e-8 and np.std(p_shap) > 1e-8:\n",
543
  " pccs.append(pearsonr(base, p_shap)[0])\n",
 
607
  "\n",
608
  "for idx in stability_idx[:10]:\n",
609
  " sample = X_test[idx]\n",
610
+ " sv_raw = explainer.shap_values(sample.reshape(1,-1), nsamples=100, silent=True)\n",
611
+ " sv = get_shap_for_class(sv_raw, 0).flatten()\n",
612
  " \n",
613
  " base_conf = predict_fn(sample.reshape(1,-1))[0]\n",
614
  " pred_cls = np.argmax(base_conf)\n",
 
632
  "# Stability summary plot\n",
633
  "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
634
  "\n",
 
635
  "eps_list = list(stability_results.keys())\n",
636
  "axes[0].plot(eps_list, [stability_results[e]['sens_max'] for e in eps_list], 'o-', color='steelblue', markersize=8)\n",
637
  "axes[0].set_xlabel('Perturbation epsilon'); axes[0].set_ylabel('SENS_MAX')\n",
638
  "axes[0].set_title('SHAP Sensitivity (lower = more stable)'); axes[0].grid(alpha=0.3)\n",
639
  "\n",
 
640
  "pcc_vals = [stability_results[e]['pcc'] for e in eps_list]\n",
641
  "colors = ['green' if p > 0.6 else 'red' for p in pcc_vals]\n",
642
  "axes[1].bar(range(len(eps_list)), pcc_vals, color=colors)\n",
 
644
  "axes[1].axhline(y=0.6, color='gray', linestyle='--', label='Threshold (0.6)')\n",
645
  "axes[1].set_ylabel('Mean PCC'); axes[1].set_title('SHAP Stability'); axes[1].legend()\n",
646
  "\n",
 
647
  "ks = list(faith_results.keys())\n",
648
  "axes[2].bar(range(len(ks)), [np.mean(faith_results[k]) for k in ks],\n",
649
  " yerr=[np.std(faith_results[k]) for k in ks], color='coral', capsize=5)\n",
 
667
  "metadata": {},
668
  "outputs": [],
669
  "source": [
 
670
  "manipulable = {'src_bytes', 'dst_bytes', 'hot', 'num_failed_logins', 'duration', 'num_compromised',\n",
671
  " 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files'}\n",
672
  "partial = {'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',\n",
673
  " 'protocol_type', 'flag', 'service'}\n",
 
 
 
 
674
  "\n",
675
  "print('SECURITY ANALYSIS: Top 15 Features by Manipulability')\n",
676
  "print('='*70)\n",
 
700
  "metadata": {},
701
  "outputs": [],
702
  "source": [
 
703
  "print('\\n' + '='*60)\n",
704
  "print('FINAL RESULTS SUMMARY')\n",
705
  "print('='*60)\n",