pavanmutha commited on
Commit
9a9b1e2
·
verified ·
1 Parent(s): d0adcb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -125,20 +125,24 @@ def explainability(_):
125
  else:
126
  sv = shap_values
127
 
128
- # Align number of columns
129
- if sv.shape[1] != X_test.shape[1]:
130
- X_test_trimmed = pd.DataFrame(X_test.values[:, :sv.shape[1]], columns=[f"Feature_{i}" for i in range(sv.shape[1])])
 
131
  else:
132
- X_test_trimmed = X_test
133
 
134
- shap.summary_plot(sv, X_test_trimmed, show=False)
 
 
135
  shap_path = "./shap_plot.png"
136
  plt.title("SHAP Summary")
137
  plt.savefig(shap_path)
138
  if wandb.run:
139
- wandb.log({"shap_summary": wandb.Image(shap_path)})
140
  plt.clf()
141
 
 
142
  except Exception as e:
143
  shap_path = "./shap_error.png"
144
  print("SHAP plotting failed:", e)
 
125
  else:
126
  sv = shap_values
127
 
128
+ # Use safe feature names if mismatch, fallback to dummy
129
+ num_features = sv.shape[1]
130
+ if num_features <= X_test.shape[1]:
131
+ feature_names = X_test.columns[:num_features]
132
  else:
133
+ feature_names = [f"Feature_{i}" for i in range(num_features)]
134
 
135
+ X_shap_safe = pd.DataFrame(np.zeros_like(sv), columns=feature_names)
136
+
137
+ shap.summary_plot(sv, X_shap_safe, show=False)
138
  shap_path = "./shap_plot.png"
139
  plt.title("SHAP Summary")
140
  plt.savefig(shap_path)
141
  if wandb.run:
142
+ wandb.log({"shap_summary": wandb.Image(shap_path)})
143
  plt.clf()
144
 
145
+
146
  except Exception as e:
147
  shap_path = "./shap_error.png"
148
  print("SHAP plotting failed:", e)