Update app.py
Browse files
app.py
CHANGED
|
@@ -118,19 +118,21 @@ def explainability(_):
|
|
| 118 |
|
| 119 |
shap_path = None
|
| 120 |
if isinstance(shap_values, list):
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
else:
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 136 |
X_train.values,
|
|
|
|
| 118 |
|
| 119 |
shap_path = None
|
| 120 |
if isinstance(shap_values, list):
|
| 121 |
+
for i, class_vals in enumerate(shap_values):
|
| 122 |
+
shap.summary_plot(class_vals, X_test, show=False)
|
| 123 |
+
class_path = f"./shap_class_{i}.png"
|
| 124 |
+
plt.title(f"SHAP Summary - Class {i}")
|
| 125 |
+
plt.savefig(class_path)
|
| 126 |
+
wandb.log({f"shap_class_{i}": wandb.Image(class_path)})
|
| 127 |
+
plt.clf()
|
| 128 |
+
if shap_path is None:
|
| 129 |
+
shap_path = class_path
|
| 130 |
else:
|
| 131 |
+
shap.summary_plot(shap_values, X_test, show=False)
|
| 132 |
+
shap_path = "./shap_plot.png"
|
| 133 |
+
plt.savefig(shap_path)
|
| 134 |
+
wandb.log({"shap_summary": wandb.Image(shap_path)})
|
| 135 |
+
plt.clf()
|
| 136 |
|
| 137 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 138 |
X_train.values,
|