03chrisk committed on
Commit
153c799
·
1 Parent(s): 22d48ee

Added a Python script to compute SHAP values and make a feature importance plot

Browse files
Files changed (1) hide show
  1. extra_scripts/shap_values.py +75 -0
extra_scripts/shap_values.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from prediction import PredictorModels
4
+ import pandas as pd
5
+ import shap
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
# Processed test features, read relative to the current working directory.
X_TEST_PATH = "data/processed/x_test.csv"

# Labels for the per-output summary plots. They are indexed in the same order
# as the last axis of the SHAP array -- assumed to match the model's output
# order; TODO confirm against the training pipeline.
OUTPUT_FEATURES = ['NO2 - Day 1', 'O3 - Day 1', 'NO2 - Day 2', 'O3 - Day 2', 'NO2 - Day 3', 'O3 - Day 3']


def normalize_shap_values(shap_values):
    """Return SHAP values as a single NumPy array.

    Args:
        shap_values: Either an array-like of SHAP values, or a list/tuple
            holding one ``(n_samples, n_features)`` array per model output
            (the layout returned by older SHAP releases for multi-output
            models).

    Returns:
        A NumPy array. A list/tuple input is stacked along a new last axis,
        giving ``(n_samples, n_features, n_outputs)``; any other input is
        converted with ``np.asarray`` and keeps its shape.
    """
    if isinstance(shap_values, (list, tuple)):
        return np.stack([np.asarray(per_output) for per_output in shap_values], axis=-1)
    return np.asarray(shap_values)


def compute_shap_importance(shap_values, feature_names):
    """Rank features by mean absolute SHAP value.

    Args:
        shap_values: SHAP values as accepted by ``normalize_shap_values``.
            After normalisation they must be 2-D ``(n_samples, n_features)``
            or 3-D ``(n_samples, n_features, n_outputs)``.
        feature_names: One name per feature, in column order.

    Returns:
        A DataFrame with columns ``feature`` and ``importance``, sorted by
        ``importance`` in descending order.

    Raises:
        ValueError: If the normalised SHAP values are neither 2-D nor 3-D.
    """
    values = normalize_shap_values(shap_values)

    if values.ndim == 3:
        # Sum over the output dimension to get one value per sample/feature.
        # NOTE(review): the signed values are summed *before* the absolute
        # value is taken, so contributions of opposite sign on different
        # outputs cancel each other. Kept as in the original script so the
        # reported numbers do not change; taking np.abs first would avoid the
        # cancellation if that is not intended.
        values = values.sum(axis=2)
    elif values.ndim != 2:
        raise ValueError(
            f"Expected 2-D or 3-D SHAP values, got an array with {values.ndim} dimension(s)"
        )

    return pd.DataFrame({
        'feature': feature_names,
        'importance': np.abs(values).mean(axis=0),
    }).sort_values(by='importance', ascending=False)


def plot_importance(shap_importance):
    """Draw a horizontal bar chart of feature importances on a new figure.

    Args:
        shap_importance: DataFrame with ``feature`` and ``importance``
            columns, as returned by ``compute_shap_importance``.
    """
    plt.figure(figsize=(10, 6))
    bars = plt.barh(shap_importance['feature'], shap_importance['importance'], color='skyblue')

    # Annotate each bar with its value, placed at the end of the bar.
    for bar in bars:
        width = bar.get_width()
        plt.text(
            width,
            bar.get_y() + bar.get_height() / 2,
            f'{width:.4f}',
            va='center',
        )

    plt.xlabel("Mean |SHAP value| (Feature Importance)")
    plt.ylabel("Feature")
    plt.title("Overall Feature Importance based on SHAP values")
    # Rows arrive sorted by descending importance; inverting the y-axis puts
    # the first row at the top of the chart.
    plt.gca().invert_yaxis()


def save_summary_plots(shap_values, x_test, output_features=OUTPUT_FEATURES):
    """Save one SHAP summary (dot) plot per model output as a PNG file.

    This was disabled code in the original script; it is kept as an opt-in
    helper and is not called by ``main``.

    Args:
        shap_values: SHAP values as accepted by ``normalize_shap_values``;
            must be 3-D after normalisation.
        x_test: Feature DataFrame the SHAP values were computed for.
        output_features: One label per model output, used in the plot title
            and in the output file name.

    Raises:
        ValueError: If the SHAP values are not 3-D, or the number of labels
            does not match the number of outputs.
    """
    values = normalize_shap_values(shap_values)
    if values.ndim != 3:
        raise ValueError(
            f"Expected 3-D SHAP values, got an array with {values.ndim} dimension(s)"
        )
    if len(output_features) != values.shape[2]:
        raise ValueError(
            f"Got {len(output_features)} output labels for {values.shape[2]} model outputs"
        )

    for i, label in enumerate(output_features):
        print(f"Generating summary plot for {label}")
        plt.figure(figsize=(33, 16))

        shap.summary_plot(values[:, :, i], x_test, plot_type="dot", show=False)

        plt.title(f"SHAP Summary Plot for {label}")

        plt.savefig(f"shap_summary_plot_{label.replace(' ', '_').replace('-', '')}.png", format='png', dpi=300, bbox_inches='tight')

        # Close the figure so figures do not accumulate across iterations.
        plt.close()


def main():
    """Compute SHAP values for the XGBoost model and show the importance plot."""
    # The first CSV column is used as the row index, not as a feature.
    x_test = pd.read_csv(X_TEST_PATH, index_col=0)

    predictor = PredictorModels()

    # NOTE(review): `_xgboost` is a private attribute of PredictorModels;
    # prefer a public accessor if the class offers one.
    explainer = shap.TreeExplainer(predictor._xgboost)
    shap_values = explainer.shap_values(x_test)

    shap_importance = compute_shap_importance(shap_values, x_test.columns)
    # Optional export, disabled in the original script:
    # shap_importance.to_csv("shap_importance.csv", index=False)

    plot_importance(shap_importance)
    # Optional export, disabled in the original script. NOTE(review): the
    # original comment mentions a shap_data folder inside the data folder,
    # but the path below is relative to the working directory -- confirm.
    # plt.savefig("shap_data/shap_feature_importance.png", format='png', dpi=300, bbox_inches='tight')

    # Without this call the figure is built and then discarded on exit.
    plt.show()

    # Per-output summary plots are available on demand:
    # save_summary_plots(shap_values, x_test)


if __name__ == "__main__":
    main()