pavanmutha commited on
Commit
3a00281
·
verified ·
1 Parent(s): a1d13d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -19
app.py CHANGED
@@ -100,6 +100,9 @@ def train_model(_):
100
  return metrics, top_trials
101
 
102
  def explainability(_):
 
 
 
103
  target = df_global.columns[-1]
104
  X = df_global.drop(target, axis=1)
105
  y = df_global[target]
@@ -112,41 +115,50 @@ def explainability(_):
112
  model = RandomForestClassifier()
113
  model.fit(X_train, y_train)
114
 
 
115
  explainer = shap.TreeExplainer(model)
116
  shap_values = explainer.shap_values(X_test)
117
 
118
- shap_path = None
119
- if isinstance(shap_values, list):
120
- for i, class_vals in enumerate(shap_values):
121
- shap.summary_plot(class_vals, X_test, show=False)
122
- class_path = f"./shap_class_{i}.png"
123
- plt.title(f"SHAP Summary - Class {i}")
124
- plt.savefig(class_path)
125
- wandb.log({f"shap_class_{i}": wandb.Image(class_path)})
126
- plt.clf()
127
- if shap_path is None:
128
- shap_path = class_path
129
- else:
130
- shap.summary_plot(shap_values, X_test, show=False)
131
- shap_path = "./shap_plot.png"
132
  plt.savefig(shap_path)
133
  wandb.log({"shap_summary": wandb.Image(shap_path)})
134
  plt.clf()
 
 
 
 
 
 
 
 
135
 
 
136
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
137
  X_train.values,
138
  feature_names=X_train.columns.tolist(),
139
- class_names=[str(label) for label in np.unique(y_train)],
140
  mode='classification'
141
  )
142
  lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
143
  lime_fig = lime_exp.as_pyplot_figure()
144
- lime_fig_path = "./lime_plot.png"
145
- lime_fig.savefig(lime_fig_path)
146
- wandb.log({"lime_explanation": wandb.Image(lime_fig_path)})
147
  plt.clf()
148
 
149
- return shap_path, lime_fig_path
 
150
 
151
  with gr.Blocks() as demo:
152
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
 
100
  return metrics, top_trials
101
 
102
  def explainability(_):
103
+ import warnings
104
+ warnings.filterwarnings("ignore")
105
+
106
  target = df_global.columns[-1]
107
  X = df_global.drop(target, axis=1)
108
  y = df_global[target]
 
115
  model = RandomForestClassifier()
116
  model.fit(X_train, y_train)
117
 
118
+ # SHAP Explainability
119
  explainer = shap.TreeExplainer(model)
120
  shap_values = explainer.shap_values(X_test)
121
 
122
+ try:
123
+ if isinstance(shap_values, list): # Multiclass
124
+ class_idx = 0
125
+ shap_matrix = shap_values[class_idx] # shape: (samples, features)
126
+ assert shap_matrix.shape[1] == X_test.shape[1], \
127
+ f"SHAP mismatch: shap {shap_matrix.shape} vs X {X_test.shape}"
128
+ shap.summary_plot(shap_matrix, X_test, show=False)
129
+ shap_path = f"./shap_class_{class_idx}.png"
130
+ plt.title(f"SHAP Summary - Class {class_idx}")
131
+ else:
132
+ shap.summary_plot(shap_values, X_test, show=False)
133
+ shap_path = "./shap_plot.png"
 
 
134
  plt.savefig(shap_path)
135
  wandb.log({"shap_summary": wandb.Image(shap_path)})
136
  plt.clf()
137
+ except Exception as e:
138
+ shap_path = "./shap_error.png"
139
+ print(f"SHAP plotting failed: {e}")
140
+ plt.figure(figsize=(6, 3))
141
+ plt.text(0.5, 0.5, f"SHAP error:\n{str(e)}", ha='center', va='center')
142
+ plt.savefig(shap_path)
143
+ wandb.log({"shap_error": wandb.Image(shap_path)})
144
+ plt.clf()
145
 
146
+ # LIME Explainability
147
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
148
  X_train.values,
149
  feature_names=X_train.columns.tolist(),
150
+ class_names=[str(c) for c in np.unique(y_train)],
151
  mode='classification'
152
  )
153
  lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
154
  lime_fig = lime_exp.as_pyplot_figure()
155
+ lime_path = "./lime_plot.png"
156
+ lime_fig.savefig(lime_path)
157
+ wandb.log({"lime_explanation": wandb.Image(lime_path)})
158
  plt.clf()
159
 
160
+ return shap_path, lime_path
161
+
162
 
163
  with gr.Blocks() as demo:
164
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")