Update app.py
Browse files
app.py
CHANGED
|
@@ -118,21 +118,19 @@ def explainability(_):
|
|
| 118 |
|
| 119 |
shap_path = None
|
| 120 |
if isinstance(shap_values, list):
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
if shap_path is None:
|
| 129 |
-
shap_path = class_path
|
| 130 |
else:
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
|
| 137 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 138 |
X_train.values,
|
|
|
|
| 118 |
|
| 119 |
shap_path = None
|
| 120 |
if isinstance(shap_values, list):
|
| 121 |
+
class_idx = 0 # or dynamically pick class with most samples
|
| 122 |
+
shap.summary_plot(shap_values[class_idx], X_test, show=False)
|
| 123 |
+
shap_path = f"./shap_class_{class_idx}.png"
|
| 124 |
+
plt.title(f"SHAP Summary - Class {class_idx}")
|
| 125 |
+
plt.savefig(shap_path)
|
| 126 |
+
wandb.log({f"shap_class_{class_idx}": wandb.Image(shap_path)})
|
| 127 |
+
plt.clf()
|
|
|
|
|
|
|
| 128 |
else:
|
| 129 |
+
shap.summary_plot(shap_values, X_test, show=False)
|
| 130 |
+
shap_path = "./shap_plot.png"
|
| 131 |
+
plt.savefig(shap_path)
|
| 132 |
+
wandb.log({"shap_summary": wandb.Image(shap_path)})
|
| 133 |
+
plt.clf()
|
| 134 |
|
| 135 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 136 |
X_train.values,
|