Spaces:
Sleeping
Sleeping
| from prediction import PredictorModels | |
| import pandas as pd | |
| import shap | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Names of the model's outputs, in the order of the last axis of the
# multi-output SHAP array (used only by the optional per-output plots).
OUTPUT_FEATURES = (
    "NO2 - Day 1",
    "O3 - Day 1",
    "NO2 - Day 2",
    "O3 - Day 2",
    "NO2 - Day 3",
    "O3 - Day 3",
)


def shap_feature_importance(shap_values, feature_names):
    """Return features ranked by mean absolute SHAP value.

    Args:
        shap_values: SHAP values shaped (n_samples, n_features) for a
            single-output model, (n_samples, n_features, n_outputs) for a
            multi-output model, or a list of per-output
            (n_samples, n_features) arrays.
        feature_names: Sequence of n_features feature names.

    Returns:
        DataFrame with columns "feature" and "importance", sorted by
        importance in descending order. Importance is |SHAP| averaged over
        samples and outputs.

    Raises:
        ValueError: if the SHAP values are not 2-D or 3-D.
    """
    if isinstance(shap_values, list):
        # One array per output -> stack outputs onto a trailing axis.
        values = np.stack(shap_values, axis=-1)
    else:
        values = np.asarray(shap_values)
    if values.ndim == 2:
        # Single-output model: treat as one output.
        values = values[:, :, np.newaxis]
    if values.ndim != 3:
        raise ValueError(
            f"expected 2-D or 3-D SHAP values, got shape {values.shape}"
        )
    # Take |.| BEFORE aggregating over outputs so that opposite-signed
    # contributions to different outputs do not cancel each other out.
    importance = np.abs(values).mean(axis=(0, 2))
    return pd.DataFrame(
        {"feature": list(feature_names), "importance": importance}
    ).sort_values(by="importance", ascending=False)


def plot_feature_importance(shap_importance):
    """Draw a horizontal bar chart of feature importances on a new figure.

    Args:
        shap_importance: DataFrame with "feature" and "importance" columns,
            as returned by ``shap_feature_importance``.

    Returns:
        The matplotlib Figure that was drawn on.
    """
    fig = plt.figure(figsize=(10, 6))
    bars = plt.barh(
        shap_importance["feature"], shap_importance["importance"], color="skyblue"
    )
    # Annotate each bar with its value, placed at the bar's end.
    for bar in bars:
        plt.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            f"{bar.get_width():.4f}",
            va="center",
        )
    plt.xlabel("Mean |SHAP value| (Feature Importance)")
    plt.ylabel("Feature")
    plt.title("Overall Feature Importance based on SHAP values")
    # barh draws the first row at the bottom; invert so the most important
    # feature (first after the descending sort) appears at the top.
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig


def plot_per_output_summaries(shap_values, x_test, output_features=OUTPUT_FEATURES):
    """Save one SHAP dot summary plot per model output as a PNG file.

    Args:
        shap_values: 3-D SHAP values shaped (n_samples, n_features, n_outputs).
        x_test: Feature matrix the SHAP values were computed on.
        output_features: One display name per output, in output-axis order.

    Raises:
        ValueError: if the number of names does not match the number of
            outputs in ``shap_values``.
    """
    n_outputs = shap_values.shape[2]
    if len(output_features) != n_outputs:
        raise ValueError(
            f"got {len(output_features)} output names for {n_outputs} outputs"
        )
    for i, name in enumerate(output_features):
        print(f"Generating summary plot for {name}")
        plt.figure(figsize=(33, 16))
        shap.summary_plot(shap_values[:, :, i], x_test, plot_type="dot", show=False)
        plt.title(f"SHAP Summary Plot for {name}")
        plt.savefig(
            f"shap_summary_plot_{name.replace(' ', '_').replace('-', '')}.png",
            format="png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close()


if __name__ == "__main__":
    x_test = pd.read_csv("data/processed/x_test.csv", index_col=0)
    predictor = PredictorModels()
    # NOTE(review): reaches into a private attribute of PredictorModels; a
    # public accessor on that class would be less brittle.
    xgb_model = predictor._xgboost
    explainer = shap.TreeExplainer(xgb_model)
    shap_values = explainer.shap_values(x_test)

    shap_importance = shap_feature_importance(shap_values, x_test.columns)
    # shap_importance.to_csv("shap_importance.csv", index=False)

    plot_feature_importance(shap_importance)
    # To also write the chart to disk (path is relative to the working
    # directory; save before show(), which may release the figure):
    # plt.savefig("shap_data/shap_feature_importance.png", format='png', dpi=300, bbox_inches='tight')
    plt.show()

    # Optional per-output summary plots (expects 3-D multi-output SHAP values):
    # plot_per_output_summaries(shap_values, x_test)