pavanmutha commited on
Commit
4e0491f
·
verified ·
1 Parent(s): b92ba43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -118,21 +118,19 @@ def explainability(_):
118
 
119
  shap_path = None
120
  if isinstance(shap_values, list):
121
- for i, class_vals in enumerate(shap_values):
122
- shap.summary_plot(class_vals, X_test, show=False)
123
- class_path = f"./shap_class_{i}.png"
124
- plt.title(f"SHAP Summary - Class {i}")
125
- plt.savefig(class_path)
126
- wandb.log({f"shap_class_{i}": wandb.Image(class_path)})
127
- plt.clf()
128
- if shap_path is None:
129
- shap_path = class_path
130
  else:
131
- shap.summary_plot(shap_values, X_test, show=False)
132
- shap_path = "./shap_plot.png"
133
- plt.savefig(shap_path)
134
- wandb.log({"shap_summary": wandb.Image(shap_path)})
135
- plt.clf()
136
 
137
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
138
  X_train.values,
 
118
 
119
  shap_path = None
120
  if isinstance(shap_values, list):
121
+ class_idx = 0 # or dynamically pick class with most samples
122
+ shap.summary_plot(shap_values[class_idx], X_test, show=False)
123
+ shap_path = f"./shap_class_{class_idx}.png"
124
+ plt.title(f"SHAP Summary - Class {class_idx}")
125
+ plt.savefig(shap_path)
126
+ wandb.log({f"shap_class_{class_idx}": wandb.Image(shap_path)})
127
+ plt.clf()
 
 
128
  else:
129
+ shap.summary_plot(shap_values, X_test, show=False)
130
+ shap_path = "./shap_plot.png"
131
+ plt.savefig(shap_path)
132
+ wandb.log({"shap_summary": wandb.Image(shap_path)})
133
+ plt.clf()
134
 
135
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
136
  X_train.values,