File size: 2,284 Bytes
153c799
 
 
 
 
88b8e22
153c799
 
 
 
 
88b8e22
153c799
 
 
 
 
 
 
 
88b8e22
 
 
153c799
88b8e22
153c799
 
 
88b8e22
 
 
153c799
 
 
 
88b8e22
153c799
88b8e22
 
153c799
 
 
 
 
 
88b8e22
153c799
 
 
 
88b8e22
153c799
 
d3d65f2
1b80238
153c799
 
 
1b80238
153c799
1b80238
153c799
1b80238
153c799
1b80238
153c799
88b8e22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from prediction import PredictorModels
import pandas as pd
import shap
import matplotlib.pyplot as plt
import numpy as np

# Script: compute global feature importance for the tree model held by
# PredictorModels using SHAP (TreeExplainer) on the processed test set, and
# draw a horizontal bar chart of the result.

# Display names for the model outputs, in output-column order. Only used by
# the optional per-output summary plots below.
# NOTE(review): assumes the model predicts exactly these six targets in this
# order — confirm against PredictorModels / the training targets.
OUTPUT_FEATURES = [
    "NO2 - Day 1",
    "O3 - Day 1",
    "NO2 - Day 2",
    "O3 - Day 2",
    "NO2 - Day 3",
    "O3 - Day 3",
]

# Set to True to also write one SHAP dot summary plot (PNG) per model output.
PLOT_PER_OUTPUT_SUMMARIES = False


def _as_array(shap_values) -> np.ndarray:
    """Return SHAP values as an array with samples on axis 0, features on axis 1.

    Some ``shap`` versions return a list holding one (samples, features)
    array per output for multi-output models; such a list is stacked along a
    new last axis, giving (samples, features, outputs). Arrays pass through.
    """
    if isinstance(shap_values, list):
        return np.stack(shap_values, axis=-1)
    return np.asarray(shap_values)


def _mean_abs_shap(shap_values) -> np.ndarray:
    """Return one global importance score per feature based on mean |SHAP|.

    Args:
        shap_values: (samples, features) for a single-output model, or
            (samples, features, outputs) / a list of per-output
            (samples, features) arrays for a multi-output model.

    Returns:
        1-D array with one entry per feature. For multi-output models the
        per-output mean |SHAP| values are summed, so a feature that pushes
        different outputs in opposite directions still counts as important.

    Raises:
        ValueError: if the values are not 2-D or 3-D after normalisation.
    """
    values = _as_array(shap_values)
    if values.ndim == 2:
        return np.abs(values).mean(axis=0)
    if values.ndim == 3:
        # Take |.| BEFORE combining outputs: summing signed values first lets
        # opposite contributions to different outputs cancel each other out.
        return np.abs(values).mean(axis=0).sum(axis=1)
    raise ValueError(f"expected 2-D or 3-D SHAP values, got shape {values.shape}")


def _plot_shap_importance(shap_importance: pd.DataFrame) -> None:
    """Draw a horizontal bar chart of ``shap_importance`` on a new figure.

    ``shap_importance`` must have ``feature`` and ``importance`` columns. Rows
    are drawn top-to-bottom in their current order, and each bar is labelled
    with its value to four decimal places.
    """
    plt.figure(figsize=(10, 6))
    bars = plt.barh(
        shap_importance["feature"], shap_importance["importance"], color="skyblue"
    )

    # Label each bar with its value, placed at the end of the bar.
    for bar in bars:
        plt.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            f"{bar.get_width():.4f}",
            va="center",
        )

    plt.xlabel("Mean |SHAP value| (Feature Importance)")
    plt.ylabel("Feature")
    plt.title("Overall Feature Importance based on SHAP values")
    # barh places the first row at the bottom; invert so the first row is on top.
    plt.gca().invert_yaxis()
    # Keep long feature names from being clipped at the left edge.
    plt.tight_layout()


def _plot_per_output_summaries(shap_values, x_test: pd.DataFrame, output_features) -> None:
    """Save one SHAP dot summary plot per model output as a PNG in the CWD.

    Raises:
        ValueError: if ``shap_values`` is not 3-D with exactly
            ``len(output_features)`` outputs (prevents mislabelled plots).
    """
    values = _as_array(shap_values)
    if values.ndim != 3 or values.shape[2] != len(output_features):
        raise ValueError(
            f"expected SHAP values with {len(output_features)} outputs, "
            f"got shape {values.shape}"
        )

    for i, name in enumerate(output_features):
        print(f"Generating summary plot for {name}")
        plt.figure(figsize=(33, 16))

        shap.summary_plot(values[:, :, i], x_test, plot_type="dot", show=False)

        plt.title(f"SHAP Summary Plot for {name}")

        safe_name = name.replace(" ", "_").replace("-", "")
        plt.savefig(
            f"shap_summary_plot_{safe_name}.png",
            format="png",
            dpi=300,
            bbox_inches="tight",
        )

        # Close each figure so open figures don't accumulate across outputs.
        plt.close()


if __name__ == "__main__":
    x_test = pd.read_csv("data/processed/x_test.csv", index_col=0)

    predictor = PredictorModels()

    # NOTE(review): reads a private attribute of PredictorModels, presumably
    # the fitted XGBoost model — confirm, and consider a public accessor.
    xgb_model = predictor._xgboost
    explainer = shap.TreeExplainer(xgb_model)
    shap_values = explainer.shap_values(x_test)

    shap_importance = pd.DataFrame(
        {"feature": x_test.columns, "importance": _mean_abs_shap(shap_values)}
    ).sort_values(by="importance", ascending=False)

    # shap_importance.to_csv("shap_importance.csv", index=False)

    _plot_shap_importance(shap_importance)

    # Save the bar plot to shap_data folder in the data folder
    # plt.savefig("shap_data/shap_feature_importance.png", format='png', dpi=300, bbox_inches='tight')

    if PLOT_PER_OUTPUT_SUMMARIES:
        _plot_per_output_summaries(shap_values, x_test, OUTPUT_FEATURES)

    # Without this the bar chart is built and then discarded when the script exits.
    plt.show()