Update app.py
Browse files
app.py
CHANGED
|
@@ -100,6 +100,9 @@ def train_model(_):
|
|
| 100 |
return metrics, top_trials
|
| 101 |
|
| 102 |
def explainability(_):
|
|
|
|
|
|
|
|
|
|
| 103 |
target = df_global.columns[-1]
|
| 104 |
X = df_global.drop(target, axis=1)
|
| 105 |
y = df_global[target]
|
|
@@ -112,41 +115,50 @@ def explainability(_):
|
|
| 112 |
model = RandomForestClassifier()
|
| 113 |
model.fit(X_train, y_train)
|
| 114 |
|
|
|
|
| 115 |
explainer = shap.TreeExplainer(model)
|
| 116 |
shap_values = explainer.shap_values(X_test)
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
plt.
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
shap.summary_plot(shap_values, X_test, show=False)
|
| 131 |
-
shap_path = "./shap_plot.png"
|
| 132 |
plt.savefig(shap_path)
|
| 133 |
wandb.log({"shap_summary": wandb.Image(shap_path)})
|
| 134 |
plt.clf()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
|
|
|
| 136 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 137 |
X_train.values,
|
| 138 |
feature_names=X_train.columns.tolist(),
|
| 139 |
-
class_names=[str(
|
| 140 |
mode='classification'
|
| 141 |
)
|
| 142 |
lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
|
| 143 |
lime_fig = lime_exp.as_pyplot_figure()
|
| 144 |
-
|
| 145 |
-
lime_fig.savefig(
|
| 146 |
-
wandb.log({"lime_explanation": wandb.Image(
|
| 147 |
plt.clf()
|
| 148 |
|
| 149 |
-
return shap_path,
|
|
|
|
| 150 |
|
| 151 |
with gr.Blocks() as demo:
|
| 152 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
|
|
|
| 100 |
return metrics, top_trials
|
| 101 |
|
| 102 |
def explainability(_):
|
| 103 |
+
import warnings
|
| 104 |
+
warnings.filterwarnings("ignore")
|
| 105 |
+
|
| 106 |
target = df_global.columns[-1]
|
| 107 |
X = df_global.drop(target, axis=1)
|
| 108 |
y = df_global[target]
|
|
|
|
| 115 |
model = RandomForestClassifier()
|
| 116 |
model.fit(X_train, y_train)
|
| 117 |
|
| 118 |
+
# SHAP Explainability
|
| 119 |
explainer = shap.TreeExplainer(model)
|
| 120 |
shap_values = explainer.shap_values(X_test)
|
| 121 |
|
| 122 |
+
try:
|
| 123 |
+
if isinstance(shap_values, list): # Multiclass
|
| 124 |
+
class_idx = 0
|
| 125 |
+
shap_matrix = shap_values[class_idx] # shape: (samples, features)
|
| 126 |
+
assert shap_matrix.shape[1] == X_test.shape[1], \
|
| 127 |
+
f"SHAP mismatch: shap {shap_matrix.shape} vs X {X_test.shape}"
|
| 128 |
+
shap.summary_plot(shap_matrix, X_test, show=False)
|
| 129 |
+
shap_path = f"./shap_class_{class_idx}.png"
|
| 130 |
+
plt.title(f"SHAP Summary - Class {class_idx}")
|
| 131 |
+
else:
|
| 132 |
+
shap.summary_plot(shap_values, X_test, show=False)
|
| 133 |
+
shap_path = "./shap_plot.png"
|
|
|
|
|
|
|
| 134 |
plt.savefig(shap_path)
|
| 135 |
wandb.log({"shap_summary": wandb.Image(shap_path)})
|
| 136 |
plt.clf()
|
| 137 |
+
except Exception as e:
|
| 138 |
+
shap_path = "./shap_error.png"
|
| 139 |
+
print(f"SHAP plotting failed: {e}")
|
| 140 |
+
plt.figure(figsize=(6, 3))
|
| 141 |
+
plt.text(0.5, 0.5, f"SHAP error:\n{str(e)}", ha='center', va='center')
|
| 142 |
+
plt.savefig(shap_path)
|
| 143 |
+
wandb.log({"shap_error": wandb.Image(shap_path)})
|
| 144 |
+
plt.clf()
|
| 145 |
|
| 146 |
+
# LIME Explainability
|
| 147 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 148 |
X_train.values,
|
| 149 |
feature_names=X_train.columns.tolist(),
|
| 150 |
+
class_names=[str(c) for c in np.unique(y_train)],
|
| 151 |
mode='classification'
|
| 152 |
)
|
| 153 |
lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
|
| 154 |
lime_fig = lime_exp.as_pyplot_figure()
|
| 155 |
+
lime_path = "./lime_plot.png"
|
| 156 |
+
lime_fig.savefig(lime_path)
|
| 157 |
+
wandb.log({"lime_explanation": wandb.Image(lime_path)})
|
| 158 |
plt.clf()
|
| 159 |
|
| 160 |
+
return shap_path, lime_path
|
| 161 |
+
|
| 162 |
|
| 163 |
with gr.Blocks() as demo:
|
| 164 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|