cathrica committed on
Commit
d1e780d
·
verified ·
1 Parent(s): 9ea3b1c

Add SHAP + LIME explainability analysis

Browse files
Files changed (1) hide show
  1. explainability/shap_analysis.py +243 -0
explainability/shap_analysis.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SHAP and LIME explainability analysis for trained IDS models.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+ import torch
10
+ import shap
11
+ from lime import lime_tabular
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+ import matplotlib.pyplot as plt
15
+
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ from models.mlp_baseline import MLP_IDS
19
+ from models.lstm_model import LSTM_IDS
20
+ from models.cnn1d_model import CNN1D_IDS
21
+ from data.preprocess import load_preprocessed, FEATURE_NAMES
22
+
23
+ SEED = 42 # Shared seed for the NumPy and PyTorch global RNGs seeded just below
24
+ np.random.seed(SEED) # Makes the np.random.choice sampling in the analyses repeatable
25
+ torch.manual_seed(SEED)
26
+
27
+ DEVICE = torch.device('cpu') # SHAP works best on CPU for these models
28
+ RESULTS_DIR = 'results' # Directory that receives the saved plots and the JSON summary
29
+ MODELS_DIR = 'saved_models' # Directory searched for '<model_name>_best.pt' checkpoints
30
+ N_BACKGROUND = 100 # Number of training rows sampled as the SHAP KernelExplainer background
31
+ N_EXPLAIN = 200 # Number of test rows sampled for SHAP explanation
32
+
33
+
34
def load_model(model_class, model_name, num_classes=2, in_dim=41):
    """Instantiate ``model_class`` and load its best checkpoint for inference.

    Args:
        model_class: Model constructor accepting ``in_dim`` and
            ``num_classes`` keyword arguments.
        model_name: Checkpoint stem; weights are read from
            ``MODELS_DIR/<model_name>_best.pt``.
        num_classes: Number of output classes passed to the constructor.
        in_dim: Number of input features passed to the constructor.
            Defaults to 41, the value that used to be hard-coded here.

    Returns:
        The model with the checkpoint weights loaded, placed on ``DEVICE``
        and switched to evaluation mode.
    """
    model = model_class(in_dim=in_dim, num_classes=num_classes)

    checkpoint_path = os.path.join(MODELS_DIR, f'{model_name}_best.pt')
    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
    model.load_state_dict(state_dict)

    # Keep the model on the same device that model_predict_fn sends inputs to.
    model.to(DEVICE)
    model.eval()
    return model
43
+
44
+
45
def model_predict_fn(model, X):
    """Return class probabilities for ``X`` as a NumPy array.

    Used as the prediction callback for the SHAP and LIME explainers in
    this file.

    Args:
        model: Callable taking a float32 tensor on ``DEVICE`` and returning
            logits with the class scores along dimension 1.
        X: Array-like batch of feature rows.

    Returns:
        NumPy array of softmax probabilities computed along dimension 1.
    """
    with torch.no_grad():
        batch = torch.as_tensor(X, dtype=torch.float32, device=DEVICE)
        logits = model(batch)
        # .cpu() is a no-op on CPU but is required before .numpy() on any
        # other device, so the wrapper keeps working if DEVICE changes.
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    return probs
52
+
53
+
54
def _select_class_shap(shap_values, class_idx):
    """Return the 2-D (samples, features) SHAP matrix for one output class.

    Handles both return layouts of ``shap_values``: a list with one array
    per class, or a single array whose last axis indexes the classes.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_idx])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, class_idx]
    # Single-output case: already (samples, features).
    return values


def run_shap_analysis(model, model_name, X_train, X_test, class_names):
    """Compute SHAP values using KernelExplainer (model-agnostic).

    Samples a background set from ``X_train`` and rows to explain from
    ``X_test``, computes Kernel SHAP values, prints the top features for
    the first class and saves a summary plot and a bar plot to
    ``RESULTS_DIR``.

    Args:
        model: Trained model passed through to ``model_predict_fn``.
        model_name: Short name used in log lines, plot titles and file names.
        X_train: Training features, indexable by an integer index array.
        X_test: Test features, indexable by an integer index array.
        class_names: Class labels; only ``class_names[0]`` is used here.

    Returns:
        Tuple ``(shap_values, feature_importance, exp_idx)``: the raw value
        returned by the explainer, a list of ``(feature_name, mean |SHAP|)``
        pairs for class 0 sorted in descending order, and the indices of
        the explained test rows.
    """
    print(f"\n--- SHAP Analysis: {model_name} ---")

    # Sampling without replacement raises if more rows are requested than
    # exist, so cap both sample sizes at the data actually available.
    n_background = min(N_BACKGROUND, len(X_train))
    n_explain = min(N_EXPLAIN, len(X_test))

    # Background data
    bg_idx = np.random.choice(len(X_train), n_background, replace=False)
    background = X_train[bg_idx]

    # Samples to explain
    exp_idx = np.random.choice(len(X_test), n_explain, replace=False)
    X_explain = X_test[exp_idx]

    def predict_fn(X):
        return model_predict_fn(model, X)

    # KernelExplainer (model-agnostic, works for all architectures)
    explainer = shap.KernelExplainer(predict_fn, background)

    print(f" Computing SHAP values for {n_explain} samples...")
    shap_values = explainer.shap_values(X_explain, nsamples=200, silent=True)

    # Indexing shap_values[0] directly is only correct for the list layout;
    # on the single-array layout it would pick the first sample instead of
    # the first class.
    class0_shap = _select_class_shap(shap_values, 0)

    # --- Global Feature Importance ---
    mean_abs_shap = np.abs(class0_shap).mean(axis=0)
    feature_importance = sorted(
        zip(FEATURE_NAMES, mean_abs_shap),
        key=lambda pair: pair[1],
        reverse=True,
    )

    print(f"\n Top 10 features (by mean |SHAP| for {class_names[0]}):")
    for fname, imp in feature_importance[:10]:
        print(f" {fname:35s}: {imp:.4f}")

    # --- Save SHAP summary plot ---
    os.makedirs(RESULTS_DIR, exist_ok=True)

    plt.figure(figsize=(10, 8))
    shap.summary_plot(class0_shap, X_explain, feature_names=FEATURE_NAMES,
                      show=False, max_display=15)
    plt.title(f'SHAP Feature Importance - {model_name.upper()} ({class_names[0]})')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, f'shap_summary_{model_name}.png'), dpi=150)
    plt.close()

    # --- Save bar plot ---
    plt.figure(figsize=(10, 6))
    top_features = feature_importance[:15]
    names = [f[0] for f in top_features]
    values = [f[1] for f in top_features]
    # Reverse so the most important feature is drawn at the top of the chart.
    plt.barh(range(len(names)), values[::-1], color='steelblue')
    plt.yticks(range(len(names)), names[::-1])
    plt.xlabel('Mean |SHAP value|')
    plt.title(f'Top 15 Features - {model_name.upper()}')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, f'shap_bar_{model_name}.png'), dpi=150)
    plt.close()

    return shap_values, feature_importance, exp_idx
110
+
111
+
112
def run_lime_analysis(model, model_name, X_train, X_test, class_names, n_instances=20):
    """Run LIME on a subset of test samples.

    Explains randomly chosen rows of ``X_test``, counts how often each
    feature appears among the top-10 LIME features, prints the most
    frequent ones and saves a frequency bar plot to ``RESULTS_DIR``.

    Args:
        model: Trained model passed through to ``model_predict_fn``.
        model_name: Short name used in log lines, plot titles and file names.
        X_train: Training features used to fit the LIME explainer.
        X_test: Test features, indexable by integer index.
        class_names: Class labels, indexed by predicted class.
        n_instances: Number of test rows to explain; capped at
            ``len(X_test)``.

    Returns:
        Tuple ``(lime_results, sorted_features)``: one dict per explained
        row (``sample_idx``, ``predicted_class``, ``top_features``), and a
        list of ``(feature_name, count)`` pairs sorted by descending count.
    """
    print(f"\n--- LIME Analysis: {model_name} ---")

    def predict_fn(X):
        return model_predict_fn(model, X)

    explainer = lime_tabular.LimeTabularExplainer(
        X_train,
        feature_names=FEATURE_NAMES,
        class_names=class_names,
        discretize_continuous=True,
        random_state=SEED,
    )

    lime_results = []
    all_top_features = {}

    # Sampling without replacement raises if more rows are requested than exist.
    n_instances = min(n_instances, len(X_test))
    idx_to_explain = np.random.choice(len(X_test), n_instances, replace=False)

    for i, idx in enumerate(idx_to_explain):
        sample = X_test[idx]
        exp = explainer.explain_instance(sample, predict_fn, num_features=10, top_labels=1)

        # With top_labels=1 the explanation holds exactly one label; reading
        # it back guarantees the lookups below use the label LIME explained.
        pred_class = int(exp.top_labels[0])
        feature_weights = exp.as_list(label=pred_class)

        lime_results.append({
            'sample_idx': int(idx),
            'predicted_class': class_names[pred_class],
            'top_features': [(fw[0], float(fw[1])) for fw in feature_weights]
        })

        # Count by feature index instead of parsing the rule text: a
        # two-sided rule such as "0.25 < name <= 0.50" starts with a number,
        # so splitting the description on spaces mislabels the feature.
        for feat_idx, _weight in exp.as_map()[pred_class]:
            fname = FEATURE_NAMES[feat_idx]
            all_top_features[fname] = all_top_features.get(fname, 0) + 1

        if (i + 1) % 5 == 0:
            print(f" Explained {i+1}/{n_instances} samples")

    sorted_features = sorted(all_top_features.items(), key=lambda x: x[1], reverse=True)
    print(f"\n Top features by LIME frequency ({n_instances} samples):")
    for fname, count in sorted_features[:10]:
        print(f" {fname:35s}: appears in {count}/{n_instances} explanations")

    # Save LIME feature frequency plot; create the output directory here so
    # the function also works when called on its own.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    plt.figure(figsize=(10, 6))
    top_lime = sorted_features[:15]
    names = [f[0] for f in top_lime]
    counts = [f[1] for f in top_lime]
    # Reverse so the most frequent feature is drawn at the top of the chart.
    plt.barh(range(len(names)), counts[::-1], color='coral')
    plt.yticks(range(len(names)), names[::-1])
    plt.xlabel(f'Frequency in top-10 (out of {n_instances} samples)')
    plt.title(f'LIME Top Features - {model_name.upper()}')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, f'lime_frequency_{model_name}.png'), dpi=150)
    plt.close()

    return lime_results, sorted_features
171
+
172
+
173
def compare_shap_lime(shap_importance, lime_frequency, model_name):
    """Measure how well the SHAP and LIME feature rankings agree.

    Each input is a list of ``(feature_name, score)`` pairs already sorted
    from most to least important. Only the first 20 entries of each list
    are considered, and a feature's rank is its position in that list.

    Args:
        shap_importance: Ranked ``(feature, mean |SHAP|)`` pairs.
        lime_frequency: Ranked ``(feature, count)`` pairs.
        model_name: Name shown in the printed report.

    Returns:
        ``{'spearman_corr', 'p_value', 'n_common'}`` when at least five
        features occur in both top-20 lists, otherwise ``{'n_common': k}``.
    """
    from scipy.stats import spearmanr

    def _positions(ranked_pairs):
        # Map each feature name to its 0-based position within the top 20.
        table = {}
        for position, (feature, _score) in enumerate(ranked_pairs[:20]):
            table[feature] = position
        return table

    shap_pos = _positions(shap_importance)
    lime_pos = _positions(lime_frequency)
    shared = set(shap_pos) & set(lime_pos)
    n_shared = len(shared)

    # Guard clause: a rank correlation over fewer than five points is not reported.
    if n_shared < 5:
        print(f" Too few common features ({n_shared}) for correlation")
        return {'n_common': n_shared}

    shap_ranks = []
    lime_ranks = []
    for feature in shared:
        shap_ranks.append(shap_pos[feature])
        lime_ranks.append(lime_pos[feature])

    corr, p_value = spearmanr(shap_ranks, lime_ranks)
    print(f"\n SHAP vs LIME rank correlation ({model_name}):")
    print(f" Common features in top-20: {n_shared}")
    print(f" Spearman correlation: {corr:.4f} (p={p_value:.4f})")
    return {
        'spearman_corr': float(corr),
        'p_value': float(p_value),
        'n_common': n_shared,
    }
194
+
195
+
196
def main():
    """Run SHAP and LIME analysis for every saved model and store a summary.

    Loads the preprocessed data, then for each known architecture with a
    checkpoint in ``MODELS_DIR`` runs the SHAP analysis, the LIME analysis
    and their rank comparison. The combined results are written to
    ``RESULTS_DIR/explainability_results.json``. Architectures without a
    checkpoint are skipped with a message.
    """
    X_train, X_test, y_train, y_test, le, scaler, meta = load_preprocessed()
    class_names = meta['class_names']

    print(f"Data loaded: {X_train.shape} train, {X_test.shape} test")
    print(f"Classes: {class_names}")

    # Create the output directory up front: if every model is skipped, no
    # analysis function runs and the JSON write below would otherwise fail.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    all_xai_results = {}

    models_to_analyze = [
        ('mlp', MLP_IDS),
        ('lstm', LSTM_IDS),
        ('cnn1d', CNN1D_IDS),
    ]

    for model_name, model_class in models_to_analyze:
        model_path = os.path.join(MODELS_DIR, f'{model_name}_best.pt')
        if not os.path.exists(model_path):
            print(f" Skipping {model_name} - no saved model found")
            continue

        model = load_model(model_class, model_name, num_classes=len(class_names))

        shap_vals, shap_importance, exp_idx = run_shap_analysis(
            model, model_name, X_train, X_test, class_names
        )

        lime_results, lime_frequency = run_lime_analysis(
            model, model_name, X_train, X_test, class_names, n_instances=30
        )

        comparison = compare_shap_lime(shap_importance, lime_frequency, model_name)

        # Cast NumPy scalars to built-in types so json.dump can serialize them.
        all_xai_results[model_name] = {
            'shap_top_features': [(f, float(v)) for f, v in shap_importance[:15]],
            'lime_top_features': [(f, int(v)) for f, v in lime_frequency[:15]],
            'shap_vs_lime': comparison,
        }

    with open(os.path.join(RESULTS_DIR, 'explainability_results.json'), 'w') as f:
        json.dump(all_xai_results, f, indent=2)

    print("\nExplainability analysis complete!")
    print(f"Results saved to {RESULTS_DIR}/")
240
+
241
+
242
+ if __name__ == '__main__': # Run the analysis only when executed as a script, not on import
243
+ main()