Update app.py
Browse files
app.py
CHANGED
|
@@ -102,7 +102,7 @@ def train_model(_):
|
|
| 102 |
def explainability(_):
|
| 103 |
import warnings
|
| 104 |
warnings.filterwarnings("ignore")
|
| 105 |
-
|
| 106 |
target = df_global.columns[-1]
|
| 107 |
X = df_global.drop(target, axis=1)
|
| 108 |
y = df_global[target]
|
|
@@ -115,35 +115,44 @@ def explainability(_):
|
|
| 115 |
model = RandomForestClassifier()
|
| 116 |
model.fit(X_train, y_train)
|
| 117 |
|
| 118 |
-
# SHAP Explainability
|
| 119 |
explainer = shap.TreeExplainer(model)
|
| 120 |
shap_values = explainer.shap_values(X_test)
|
| 121 |
|
| 122 |
try:
|
| 123 |
if isinstance(shap_values, list): # Multiclass
|
| 124 |
class_idx = 0
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
shap_path = f"./shap_class_{class_idx}.png"
|
| 130 |
plt.title(f"SHAP Summary - Class {class_idx}")
|
| 131 |
else:
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
shap_path = "./shap_plot.png"
|
|
|
|
| 134 |
plt.savefig(shap_path)
|
| 135 |
-
wandb.
|
|
|
|
| 136 |
plt.clf()
|
|
|
|
| 137 |
except Exception as e:
|
| 138 |
shap_path = "./shap_error.png"
|
| 139 |
print(f"SHAP plotting failed: {e}")
|
| 140 |
plt.figure(figsize=(6, 3))
|
| 141 |
-
plt.text(0.5, 0.5, f"SHAP error
|
|
|
|
| 142 |
plt.savefig(shap_path)
|
| 143 |
-
wandb.
|
|
|
|
| 144 |
plt.clf()
|
| 145 |
|
| 146 |
-
# LIME
|
| 147 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 148 |
X_train.values,
|
| 149 |
feature_names=X_train.columns.tolist(),
|
|
@@ -154,12 +163,14 @@ def explainability(_):
|
|
| 154 |
lime_fig = lime_exp.as_pyplot_figure()
|
| 155 |
lime_path = "./lime_plot.png"
|
| 156 |
lime_fig.savefig(lime_path)
|
| 157 |
-
wandb.
|
|
|
|
| 158 |
plt.clf()
|
| 159 |
|
| 160 |
return shap_path, lime_path
|
| 161 |
|
| 162 |
|
|
|
|
| 163 |
with gr.Blocks() as demo:
|
| 164 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 165 |
|
|
|
|
| 102 |
def explainability(_):
|
| 103 |
import warnings
|
| 104 |
warnings.filterwarnings("ignore")
|
| 105 |
+
|
| 106 |
target = df_global.columns[-1]
|
| 107 |
X = df_global.drop(target, axis=1)
|
| 108 |
y = df_global[target]
|
|
|
|
| 115 |
model = RandomForestClassifier()
|
| 116 |
model.fit(X_train, y_train)
|
| 117 |
|
|
|
|
| 118 |
explainer = shap.TreeExplainer(model)
|
| 119 |
shap_values = explainer.shap_values(X_test)
|
| 120 |
|
| 121 |
try:
|
| 122 |
if isinstance(shap_values, list): # Multiclass
|
| 123 |
class_idx = 0
|
| 124 |
+
X_shap = pd.DataFrame(
|
| 125 |
+
X_test.values[:, :shap_values[class_idx].shape[1]],
|
| 126 |
+
columns=X_test.columns[:shap_values[class_idx].shape[1]]
|
| 127 |
+
)
|
| 128 |
+
shap.summary_plot(shap_values[class_idx], X_shap, show=False)
|
| 129 |
shap_path = f"./shap_class_{class_idx}.png"
|
| 130 |
plt.title(f"SHAP Summary - Class {class_idx}")
|
| 131 |
else:
|
| 132 |
+
X_shap = pd.DataFrame(
|
| 133 |
+
X_test.values[:, :shap_values.shape[1]],
|
| 134 |
+
columns=X_test.columns[:shap_values.shape[1]]
|
| 135 |
+
)
|
| 136 |
+
shap.summary_plot(shap_values, X_shap, show=False)
|
| 137 |
shap_path = "./shap_plot.png"
|
| 138 |
+
|
| 139 |
plt.savefig(shap_path)
|
| 140 |
+
if wandb.run is not None:
|
| 141 |
+
wandb.log({"shap_summary": wandb.Image(shap_path)})
|
| 142 |
plt.clf()
|
| 143 |
+
|
| 144 |
except Exception as e:
|
| 145 |
shap_path = "./shap_error.png"
|
| 146 |
print(f"SHAP plotting failed: {e}")
|
| 147 |
plt.figure(figsize=(6, 3))
|
| 148 |
+
plt.text(0.5, 0.5, f"SHAP error:\\n{str(e)}", ha='center', va='center')
|
| 149 |
+
plt.axis('off')
|
| 150 |
plt.savefig(shap_path)
|
| 151 |
+
if wandb.run is not None:
|
| 152 |
+
wandb.log({"shap_error": wandb.Image(shap_path)})
|
| 153 |
plt.clf()
|
| 154 |
|
| 155 |
+
# LIME
|
| 156 |
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
|
| 157 |
X_train.values,
|
| 158 |
feature_names=X_train.columns.tolist(),
|
|
|
|
| 163 |
lime_fig = lime_exp.as_pyplot_figure()
|
| 164 |
lime_path = "./lime_plot.png"
|
| 165 |
lime_fig.savefig(lime_path)
|
| 166 |
+
if wandb.run is not None:
|
| 167 |
+
wandb.log({"lime_explanation": wandb.Image(lime_path)})
|
| 168 |
plt.clf()
|
| 169 |
|
| 170 |
return shap_path, lime_path
|
| 171 |
|
| 172 |
|
| 173 |
+
|
| 174 |
with gr.Blocks() as demo:
|
| 175 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 176 |
|