pavanmutha commited on
Commit
27aba11
·
verified ·
1 Parent(s): 4e0491f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -118,19 +118,21 @@ def explainability(_):
118
 
119
  shap_path = None
120
  if isinstance(shap_values, list):
121
- class_idx = 0 # or dynamically pick class with most samples
122
- shap.summary_plot(shap_values[class_idx], X_test, show=False)
123
- shap_path = f"./shap_class_{class_idx}.png"
124
- plt.title(f"SHAP Summary - Class {class_idx}")
125
- plt.savefig(shap_path)
126
- wandb.log({f"shap_class_{class_idx}": wandb.Image(shap_path)})
127
- plt.clf()
 
 
128
  else:
129
- shap.summary_plot(shap_values, X_test, show=False)
130
- shap_path = "./shap_plot.png"
131
- plt.savefig(shap_path)
132
- wandb.log({"shap_summary": wandb.Image(shap_path)})
133
- plt.clf()
134
 
135
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
136
  X_train.values,
 
118
 
119
  shap_path = None
120
  if isinstance(shap_values, list):
121
+ for i, class_vals in enumerate(shap_values):
122
+ shap.summary_plot(class_vals, X_test, show=False)
123
+ class_path = f"./shap_class_{i}.png"
124
+ plt.title(f"SHAP Summary - Class {i}")
125
+ plt.savefig(class_path)
126
+ wandb.log({f"shap_class_{i}": wandb.Image(class_path)})
127
+ plt.clf()
128
+ if shap_path is None:
129
+ shap_path = class_path
130
  else:
131
+ shap.summary_plot(shap_values, X_test, show=False)
132
+ shap_path = "./shap_plot.png"
133
+ plt.savefig(shap_path)
134
+ wandb.log({"shap_summary": wandb.Image(shap_path)})
135
+ plt.clf()
136
 
137
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
138
  X_train.values,