Spaces:
Sleeping
Sleeping
File size: 2,284 Bytes
153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 88b8e22 153c799 d3d65f2 1b80238 153c799 1b80238 153c799 1b80238 153c799 1b80238 153c799 1b80238 153c799 88b8e22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from prediction import PredictorModels
import pandas as pd
import shap
import matplotlib.pyplot as plt
import numpy as np
def compute_shap_importance(shap_values, feature_names):
    """Rank features by their mean absolute SHAP value.

    Accepts every layout ``shap_values`` is handled in here:

    * a 2-D array ``(samples, features)`` (single-output model),
    * a 3-D array ``(samples, features, outputs)``,
    * a list/tuple of 2-D arrays, one per output.

    For multi-output input the signed SHAP values are summed over the
    output axis first, then the absolute value is averaged over samples.

    Args:
        shap_values: SHAP values in one of the layouts listed above.
        feature_names: One label per feature column, in column order.

    Returns:
        A DataFrame with columns ``feature`` and ``importance``, sorted by
        ``importance`` in descending order.

    Raises:
        ValueError: If the values are neither 2-D nor 3-D after normalising.
    """
    if isinstance(shap_values, (list, tuple)):
        # One (samples, features) matrix per output: put outputs on the last
        # axis so this layout is reduced the same way as a 3-D array.
        values = np.stack([np.asarray(v) for v in shap_values], axis=-1)
    else:
        values = np.asarray(shap_values)

    if values.ndim == 3:
        # Sum over the output dimension to get one value per sample/feature.
        values = values.sum(axis=2)
    elif values.ndim != 2:
        raise ValueError(
            f"Expected 2-D or 3-D SHAP values, got an array with {values.ndim} dimension(s)"
        )

    importance = np.abs(values).mean(axis=0)
    return pd.DataFrame(
        {"feature": feature_names, "importance": importance}
    ).sort_values(by="importance", ascending=False)


def main():
    """Compute SHAP feature importance for the XGBoost model and plot it.

    Reads the test features from ``data/processed/x_test.csv``, explains the
    model stored on ``PredictorModels()._xgboost`` with ``shap.TreeExplainer``
    and draws a horizontal bar chart of the ranked importances.
    """
    x_test = pd.read_csv("data/processed/x_test.csv", index_col=0)

    predictor = PredictorModels()
    xgb_model = predictor._xgboost
    explainer = shap.TreeExplainer(xgb_model)
    shap_values = explainer.shap_values(x_test)

    shap_importance = compute_shap_importance(shap_values, x_test.columns)
    # shap_importance.to_csv("shap_importance.csv", index=False)

    # PLOTTING
    plt.figure(figsize=(10, 6))
    bars = plt.barh(
        shap_importance["feature"], shap_importance["importance"], color="skyblue"
    )

    # Add text labels to the bars
    for bar in bars:
        plt.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            f"{bar.get_width():.4f}",
            va="center",
        )

    plt.xlabel("Mean |SHAP value| (Feature Importance)")
    plt.ylabel("Feature")
    plt.title("Overall Feature Importance based on SHAP values")
    # Rows are sorted descending; invert so the top feature is drawn first.
    plt.gca().invert_yaxis()

    # NOTE(review): as written the figure is neither shown nor saved; enable
    # the line below to export it.
    # plt.savefig("shap_data/shap_feature_importance.png", format='png', dpi=300, bbox_inches='tight')

    # OTHER PLOTS (disabled): one SHAP summary plot per model output.
    # output_features = ['NO2 - Day 1', 'O3 - Day 1', 'NO2 - Day 2', 'O3 - Day 2', 'NO2 - Day 3', 'O3 - Day 3']
    # shap_values = explainer.shap_values(x_test)
    # n_outputs = shap_values.shape[2]
    # for i in range(n_outputs):
    #     print(f"Generating summary plot for {output_features[i]}")
    #     plt.figure(figsize=(33, 16))
    #     shap.summary_plot(shap_values[:, :, i], x_test, plot_type="dot", show=False)
    #     plt.title(f"SHAP Summary Plot for {output_features[i]}")
    #     plt.savefig(f"shap_summary_plot_{output_features[i].replace(' ', '_').replace('-', '')}.png", format='png', dpi=300, bbox_inches='tight')
    #     plt.close()


if __name__ == "__main__":
    main()
|