pavanmutha commited on
Commit
fd6e2a3
·
verified ·
1 Parent(s): 203c9b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -19
app.py CHANGED
@@ -119,36 +119,34 @@ def explainability(_):
119
  shap_values = explainer.shap_values(X_test)
120
 
121
  try:
122
- if isinstance(shap_values, list): # Multiclass
123
  class_idx = 0
124
- X_shap = pd.DataFrame(
125
- X_test.values[:, :shap_values[class_idx].shape[1]],
126
- columns=X_test.columns[:shap_values[class_idx].shape[1]]
127
- )
128
- shap.summary_plot(shap_values[class_idx], X_shap, show=False)
129
- shap_path = f"./shap_class_{class_idx}.png"
130
- plt.title(f"SHAP Summary - Class {class_idx}")
131
  else:
132
- X_shap = pd.DataFrame(
133
- X_test.values[:, :shap_values.shape[1]],
134
- columns=X_test.columns[:shap_values.shape[1]]
135
- )
136
- shap.summary_plot(shap_values, X_shap, show=False)
137
- shap_path = "./shap_plot.png"
138
 
 
 
 
 
 
 
 
 
 
139
  plt.savefig(shap_path)
140
- if wandb.run is not None:
141
  wandb.log({"shap_summary": wandb.Image(shap_path)})
142
  plt.clf()
143
 
144
  except Exception as e:
145
  shap_path = "./shap_error.png"
146
- print(f"SHAP plotting failed: {e}")
147
  plt.figure(figsize=(6, 3))
148
- plt.text(0.5, 0.5, f"SHAP error:\\n{str(e)}", ha='center', va='center')
149
  plt.axis('off')
150
  plt.savefig(shap_path)
151
- if wandb.run is not None:
152
  wandb.log({"shap_error": wandb.Image(shap_path)})
153
  plt.clf()
154
 
@@ -163,7 +161,7 @@ def explainability(_):
163
  lime_fig = lime_exp.as_pyplot_figure()
164
  lime_path = "./lime_plot.png"
165
  lime_fig.savefig(lime_path)
166
- if wandb.run is not None:
167
  wandb.log({"lime_explanation": wandb.Image(lime_path)})
168
  plt.clf()
169
 
@@ -171,6 +169,7 @@ def explainability(_):
171
 
172
 
173
 
 
174
  with gr.Blocks() as demo:
175
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
176
 
 
119
  shap_values = explainer.shap_values(X_test)
120
 
121
  try:
122
+ if isinstance(shap_values, list):
123
  class_idx = 0
124
+ sv = shap_values[class_idx]
 
 
 
 
 
 
125
  else:
126
+ sv = shap_values
 
 
 
 
 
127
 
128
+ # Align number of columns
129
+ if sv.shape[1] != X_test.shape[1]:
130
+ X_test_trimmed = pd.DataFrame(X_test.values[:, :sv.shape[1]], columns=X_test.columns[:sv.shape[1]])
131
+ else:
132
+ X_test_trimmed = X_test
133
+
134
+ shap.summary_plot(sv, X_test_trimmed, show=False)
135
+ shap_path = "./shap_plot.png"
136
+ plt.title("SHAP Summary")
137
  plt.savefig(shap_path)
138
+ if wandb.run:
139
  wandb.log({"shap_summary": wandb.Image(shap_path)})
140
  plt.clf()
141
 
142
  except Exception as e:
143
  shap_path = "./shap_error.png"
144
+ print("SHAP plotting failed:", e)
145
  plt.figure(figsize=(6, 3))
146
+ plt.text(0.5, 0.5, f"SHAP Error:\n{str(e)}", ha='center', va='center')
147
  plt.axis('off')
148
  plt.savefig(shap_path)
149
+ if wandb.run:
150
  wandb.log({"shap_error": wandb.Image(shap_path)})
151
  plt.clf()
152
 
 
161
  lime_fig = lime_exp.as_pyplot_figure()
162
  lime_path = "./lime_plot.png"
163
  lime_fig.savefig(lime_path)
164
+ if wandb.run:
165
  wandb.log({"lime_explanation": wandb.Image(lime_path)})
166
  plt.clf()
167
 
 
169
 
170
 
171
 
172
+
173
  with gr.Blocks() as demo:
174
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
175