pavanmutha commited on
Commit
203c9b1
·
verified ·
1 Parent(s): 73620eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -102,7 +102,7 @@ def train_model(_):
102
  def explainability(_):
103
  import warnings
104
  warnings.filterwarnings("ignore")
105
-
106
  target = df_global.columns[-1]
107
  X = df_global.drop(target, axis=1)
108
  y = df_global[target]
@@ -115,35 +115,44 @@ def explainability(_):
115
  model = RandomForestClassifier()
116
  model.fit(X_train, y_train)
117
 
118
- # SHAP Explainability
119
  explainer = shap.TreeExplainer(model)
120
  shap_values = explainer.shap_values(X_test)
121
 
122
  try:
123
  if isinstance(shap_values, list): # Multiclass
124
  class_idx = 0
125
- shap_matrix = shap_values[class_idx] # shape: (samples, features)
126
- assert shap_matrix.shape[1] == X_test.shape[1], \
127
- f"SHAP mismatch: shap {shap_matrix.shape} vs X {X_test.shape}"
128
- shap.summary_plot(shap_matrix, X_test, show=False)
 
129
  shap_path = f"./shap_class_{class_idx}.png"
130
  plt.title(f"SHAP Summary - Class {class_idx}")
131
  else:
132
- shap.summary_plot(shap_values, X_test, show=False)
 
 
 
 
133
  shap_path = "./shap_plot.png"
 
134
  plt.savefig(shap_path)
135
- wandb.log({"shap_summary": wandb.Image(shap_path)})
 
136
  plt.clf()
 
137
  except Exception as e:
138
  shap_path = "./shap_error.png"
139
  print(f"SHAP plotting failed: {e}")
140
  plt.figure(figsize=(6, 3))
141
- plt.text(0.5, 0.5, f"SHAP error:\n{str(e)}", ha='center', va='center')
 
142
  plt.savefig(shap_path)
143
- wandb.log({"shap_error": wandb.Image(shap_path)})
 
144
  plt.clf()
145
 
146
- # LIME Explainability
147
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
148
  X_train.values,
149
  feature_names=X_train.columns.tolist(),
@@ -154,12 +163,14 @@ def explainability(_):
154
  lime_fig = lime_exp.as_pyplot_figure()
155
  lime_path = "./lime_plot.png"
156
  lime_fig.savefig(lime_path)
157
- wandb.log({"lime_explanation": wandb.Image(lime_path)})
 
158
  plt.clf()
159
 
160
  return shap_path, lime_path
161
 
162
 
 
163
  with gr.Blocks() as demo:
164
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
165
 
 
102
  def explainability(_):
103
  import warnings
104
  warnings.filterwarnings("ignore")
105
+
106
  target = df_global.columns[-1]
107
  X = df_global.drop(target, axis=1)
108
  y = df_global[target]
 
115
  model = RandomForestClassifier()
116
  model.fit(X_train, y_train)
117
 
 
118
  explainer = shap.TreeExplainer(model)
119
  shap_values = explainer.shap_values(X_test)
120
 
121
  try:
122
  if isinstance(shap_values, list): # Multiclass
123
  class_idx = 0
124
+ X_shap = pd.DataFrame(
125
+ X_test.values[:, :shap_values[class_idx].shape[1]],
126
+ columns=X_test.columns[:shap_values[class_idx].shape[1]]
127
+ )
128
+ shap.summary_plot(shap_values[class_idx], X_shap, show=False)
129
  shap_path = f"./shap_class_{class_idx}.png"
130
  plt.title(f"SHAP Summary - Class {class_idx}")
131
  else:
132
+ X_shap = pd.DataFrame(
133
+ X_test.values[:, :shap_values.shape[1]],
134
+ columns=X_test.columns[:shap_values.shape[1]]
135
+ )
136
+ shap.summary_plot(shap_values, X_shap, show=False)
137
  shap_path = "./shap_plot.png"
138
+
139
  plt.savefig(shap_path)
140
+ if wandb.run is not None:
141
+ wandb.log({"shap_summary": wandb.Image(shap_path)})
142
  plt.clf()
143
+
144
  except Exception as e:
145
  shap_path = "./shap_error.png"
146
  print(f"SHAP plotting failed: {e}")
147
  plt.figure(figsize=(6, 3))
148
+ plt.text(0.5, 0.5, f"SHAP error:\\n{str(e)}", ha='center', va='center')
149
+ plt.axis('off')
150
  plt.savefig(shap_path)
151
+ if wandb.run is not None:
152
+ wandb.log({"shap_error": wandb.Image(shap_path)})
153
  plt.clf()
154
 
155
+ # LIME
156
  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
157
  X_train.values,
158
  feature_names=X_train.columns.tolist(),
 
163
  lime_fig = lime_exp.as_pyplot_figure()
164
  lime_path = "./lime_plot.png"
165
  lime_fig.savefig(lime_path)
166
+ if wandb.run is not None:
167
+ wandb.log({"lime_explanation": wandb.Image(lime_path)})
168
  plt.clf()
169
 
170
  return shap_path, lime_path
171
 
172
 
173
+
174
  with gr.Blocks() as demo:
175
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
176