atodorov284
Format code for flake8
d3d65f2
from prediction import PredictorModels
import pandas as pd
import shap
import matplotlib.pyplot as plt
import numpy as np
# Labels for the model outputs, in the order of the last axis of the SHAP array.
OUTPUT_FEATURES = (
    "NO2 - Day 1",
    "O3 - Day 1",
    "NO2 - Day 2",
    "O3 - Day 2",
    "NO2 - Day 3",
    "O3 - Day 3",
)


def _stack_outputs(shap_values):
    """Normalise SHAP values to a float array shaped (n_samples, n_features, n_outputs).

    Accepts a 3-D array whose last axis indexes outputs, a 2-D array
    (single-output model), or a list holding one (n_samples, n_features)
    array per output (the layout some shap releases return for
    multi-output models).
    """
    if isinstance(shap_values, list):
        if not shap_values:
            raise ValueError("shap_values is an empty list")
        values = np.stack(
            [np.asarray(v, dtype=float) for v in shap_values], axis=-1
        )
    else:
        values = np.asarray(shap_values, dtype=float)
    if values.ndim == 2:
        # Single-output model: add a trailing output axis of length 1.
        values = values[:, :, np.newaxis]
    if values.ndim != 3:
        raise ValueError(
            f"expected 2-D or 3-D SHAP values, got shape {values.shape}"
        )
    return values


def compute_shap_importance(shap_values, feature_names):
    """Rank features by mean absolute SHAP value.

    Args:
        shap_values: SHAP values in any layout accepted by ``_stack_outputs``.
        feature_names: Names for the feature axis, in column order.

    Returns:
        DataFrame with ``feature`` and ``importance`` columns, sorted by
        descending importance. For multi-output models the importance of a
        feature is its mean |SHAP| for each output, summed over outputs.

    Raises:
        ValueError: if the values are empty, have an unsupported shape, or
            the number of names does not match the feature axis.
    """
    values = _stack_outputs(shap_values)
    if values.shape[0] == 0:
        raise ValueError("shap_values contains no samples")
    names = list(feature_names)
    if len(names) != values.shape[1]:
        raise ValueError(
            f"got {len(names)} feature names for {values.shape[1]} features"
        )
    # Take |.| before combining outputs so that opposite-signed contributions
    # to different targets cannot cancel each other out.
    importance = np.abs(values).mean(axis=0).sum(axis=1)
    return pd.DataFrame({"feature": names, "importance": importance}).sort_values(
        by="importance", ascending=False, kind="stable"
    )


def plot_shap_importance(importance, output_path=None):
    """Draw a horizontal bar chart of feature importance.

    Args:
        importance: DataFrame as returned by ``compute_shap_importance``.
        output_path: If given, save the figure there as a 300-dpi PNG and
            close it. The parent directory must already exist. Otherwise
            show the figure interactively.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.barh(importance["feature"], importance["importance"], color="skyblue")
    # Print each bar's value just past its end.
    for bar in bars:
        ax.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            f"{bar.get_width():.4f}",
            va="center",
        )
    ax.set_xlabel("Mean |SHAP value| (Feature Importance)")
    ax.set_ylabel("Feature")
    ax.set_title("Overall Feature Importance based on SHAP values")
    # barh draws the first row at the bottom; invert so the most important is on top.
    ax.invert_yaxis()
    fig.tight_layout()
    if output_path is None:
        plt.show()
    else:
        fig.savefig(output_path, format="png", dpi=300, bbox_inches="tight")
        plt.close(fig)


def plot_per_output_summaries(shap_values, x_test, output_features=OUTPUT_FEATURES):
    """Save one SHAP dot summary plot per model output (opt-in; not run by main).

    Files are written to the working directory as
    ``shap_summary_plot_<sanitised output name>.png``.

    Raises:
        ValueError: if the number of labels differs from the number of outputs.
    """
    values = _stack_outputs(shap_values)
    if len(output_features) != values.shape[2]:
        raise ValueError(
            f"got {len(output_features)} output labels for {values.shape[2]} outputs"
        )
    for i, name in enumerate(output_features):
        print(f"Generating summary plot for {name}")
        plt.figure(figsize=(33, 16))
        shap.summary_plot(values[:, :, i], x_test, plot_type="dot", show=False)
        plt.title(f"SHAP Summary Plot for {name}")
        plt.savefig(
            f"shap_summary_plot_{name.replace(' ', '_').replace('-', '')}.png",
            format="png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close()


def main():
    """Explain the XGBoost model on the held-out test set and plot importance."""
    x_test = pd.read_csv("data/processed/x_test.csv", index_col=0)
    predictor = PredictorModels()
    # NOTE(review): reaches into the private ``_xgboost`` attribute of
    # PredictorModels; prefer a public accessor if one exists.
    explainer = shap.TreeExplainer(predictor._xgboost)
    shap_values = explainer.shap_values(x_test)
    importance = compute_shap_importance(shap_values, x_test.columns)
    plot_shap_importance(importance)


if __name__ == "__main__":
    main()