Update app.py
Browse files
app.py
CHANGED
|
@@ -31,6 +31,7 @@ login(token=hf_token)
|
|
| 31 |
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
| 32 |
|
| 33 |
df_global = None
|
|
|
|
| 34 |
|
| 35 |
def clean_data(df):
|
| 36 |
df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
|
|
@@ -43,12 +44,20 @@ def clean_data(df):
|
|
| 43 |
def upload_file(file):
|
| 44 |
global df_global
|
| 45 |
if file is None:
|
| 46 |
-
return pd.DataFrame({"Error": ["No file uploaded."]})
|
| 47 |
ext = os.path.splitext(file.name)[-1]
|
| 48 |
df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
|
| 49 |
df = clean_data(df)
|
| 50 |
df_global = df
|
| 51 |
-
return df.head()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def format_analysis_report(raw_output, visuals):
|
| 54 |
try:
|
|
@@ -155,10 +164,14 @@ def analyze_data(csv_file, additional_notes=""):
|
|
| 155 |
return format_analysis_report(analysis_result, visuals)
|
| 156 |
|
| 157 |
def compare_models():
|
|
|
|
|
|
|
|
|
|
| 158 |
if df_global is None:
|
| 159 |
-
return "Please upload and preprocess a dataset first."
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
X = df_global.drop(target, axis=1)
|
| 163 |
y = df_global[target]
|
| 164 |
|
|
@@ -168,32 +181,57 @@ def compare_models():
|
|
| 168 |
models = {
|
| 169 |
"RandomForest": RandomForestClassifier(),
|
| 170 |
"LogisticRegression": LogisticRegression(max_iter=1000),
|
| 171 |
-
"
|
| 172 |
}
|
| 173 |
|
| 174 |
results = []
|
| 175 |
for name, model in models.items():
|
|
|
|
| 176 |
scores = cross_val_score(model, X, y, cv=5)
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
"Model": name,
|
| 179 |
"CV Mean Accuracy": np.mean(scores),
|
| 180 |
-
"CV Std Dev": np.std(scores)
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
results_df = pd.DataFrame(results)
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
# 1. prepare_data should come first
|
| 188 |
-
def prepare_data(df
|
|
|
|
| 189 |
from sklearn.model_selection import train_test_split
|
| 190 |
|
| 191 |
# If no target column is specified, select the first object column or the last column
|
| 192 |
if target_column is None:
|
| 193 |
-
|
| 194 |
|
| 195 |
-
X = df.drop(columns=[
|
| 196 |
-
y = df[
|
| 197 |
|
| 198 |
return train_test_split(X, y, test_size=0.3, random_state=42)
|
| 199 |
|
|
@@ -214,8 +252,10 @@ def train_model(_):
|
|
| 214 |
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 215 |
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 216 |
}
|
| 217 |
-
model = RandomForestClassifier()
|
| 218 |
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
|
|
|
|
|
|
| 219 |
wandb.log({**params, "cv_score": score})
|
| 220 |
return score
|
| 221 |
|
|
@@ -257,7 +297,8 @@ def explainability(_):
|
|
| 257 |
import warnings
|
| 258 |
warnings.filterwarnings("ignore")
|
| 259 |
|
| 260 |
-
|
|
|
|
| 261 |
X = df_global.drop(target, axis=1)
|
| 262 |
y = df_global[target]
|
| 263 |
|
|
@@ -328,6 +369,16 @@ def explainability(_):
|
|
| 328 |
|
| 329 |
return shap_path, lime_path
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
with gr.Blocks() as demo:
|
| 332 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 333 |
|
|
@@ -335,7 +386,12 @@ with gr.Blocks() as demo:
|
|
| 335 |
with gr.Column():
|
| 336 |
file_input = gr.File(label="Upload CSV or Excel", type="filepath")
|
| 337 |
df_output = gr.DataFrame(label="Cleaned Data Preview")
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
with gr.Column():
|
| 341 |
insights_output = gr.HTML(label="Insights from SmolAgent")
|
|
@@ -352,8 +408,14 @@ with gr.Blocks() as demo:
|
|
| 352 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 353 |
lime_img = gr.Image(label="LIME Explanation")
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 356 |
train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
|
| 357 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
|
|
|
| 358 |
|
| 359 |
demo.launch(debug=True)
|
|
|
|
| 31 |
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
| 32 |
|
| 33 |
df_global = None
|
| 34 |
+
target_column_global = None
|
| 35 |
|
| 36 |
def clean_data(df):
|
| 37 |
df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
|
|
|
|
| 44 |
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and cache it globally.

    Parameters
    ----------
    file : str or file-like
        Gradio upload payload. With ``gr.File(type="filepath")`` this is a
        plain path string; older Gradio versions passed an object exposing
        ``.name`` — both are supported.

    Returns
    -------
    tuple
        ``(preview, dropdown_update)`` — the first rows of the cleaned
        DataFrame, and a ``gr.update`` carrying the column names as choices
        for the target-column dropdown.
    """
    global df_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])
    # type="filepath" yields a str; fall back to the value itself when there
    # is no .name attribute.
    path = getattr(file, "name", file)
    # Bug fix: compare the extension case-insensitively so "DATA.CSV" is not
    # silently routed to the Excel reader.
    ext = os.path.splitext(path)[-1].lower()
    df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
    df = clean_data(df)
    df_global = df
    return df.head(), gr.update(choices=df.columns.tolist())
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def set_target_column(col_name):
    """Remember *col_name* as the prediction target for later pipeline steps.

    Stores the selection in the module-level ``target_column_global`` and
    returns a human-readable confirmation for the status textbox.
    """
    global target_column_global
    target_column_global = col_name
    status_message = f"✅ Target column set to: {col_name}"
    return status_message
|
| 60 |
+
|
| 61 |
|
| 62 |
def format_analysis_report(raw_output, visuals):
|
| 63 |
try:
|
|
|
|
| 164 |
return format_analysis_report(analysis_result, visuals)
|
| 165 |
|
| 166 |
def compare_models():
|
| 167 |
+
import seaborn as sns
|
| 168 |
+
from sklearn.model_selection import cross_val_predict
|
| 169 |
+
|
| 170 |
if df_global is None:
|
| 171 |
+
return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
|
| 172 |
+
|
| 173 |
+
global target_column_global
|
| 174 |
+
target = target_column_global
|
| 175 |
X = df_global.drop(target, axis=1)
|
| 176 |
y = df_global[target]
|
| 177 |
|
|
|
|
| 181 |
models = {
|
| 182 |
"RandomForest": RandomForestClassifier(),
|
| 183 |
"LogisticRegression": LogisticRegression(max_iter=1000),
|
| 184 |
+
"GradientBoosting": GradientBoostingClassifier()
|
| 185 |
}
|
| 186 |
|
| 187 |
results = []
|
| 188 |
for name, model in models.items():
|
| 189 |
+
# Cross-validation scores
|
| 190 |
scores = cross_val_score(model, X, y, cv=5)
|
| 191 |
+
|
| 192 |
+
# Cross-validated predictions for metrics
|
| 193 |
+
y_pred = cross_val_predict(model, X, y, cv=5)
|
| 194 |
+
|
| 195 |
+
metrics = {
|
| 196 |
"Model": name,
|
| 197 |
"CV Mean Accuracy": np.mean(scores),
|
| 198 |
+
"CV Std Dev": np.std(scores),
|
| 199 |
+
"F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
|
| 200 |
+
"Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
|
| 201 |
+
"Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
|
| 202 |
+
}
|
| 203 |
+
if wandb.run is None:
|
| 204 |
+
wandb.init(project="model_comparison", name="compare_models", reinit=True)
|
| 205 |
+
wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
|
| 206 |
+
results.append(metrics)
|
| 207 |
|
| 208 |
results_df = pd.DataFrame(results)
|
| 209 |
+
|
| 210 |
+
# Plotting
|
| 211 |
+
plt.figure(figsize=(8, 5))
|
| 212 |
+
sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d")
|
| 213 |
+
plt.title("Model Comparison (CV Mean Accuracy)")
|
| 214 |
+
plt.ylim(0, 1)
|
| 215 |
+
plt.tight_layout()
|
| 216 |
+
|
| 217 |
+
plot_path = "./model_comparison.png"
|
| 218 |
+
plt.savefig(plot_path)
|
| 219 |
+
plt.close()
|
| 220 |
+
|
| 221 |
+
return results_df, plot_path
|
| 222 |
+
|
| 223 |
|
| 224 |
# 1. prepare_data should come first
|
| 225 |
+
def prepare_data(df):
    """Split *df* into train/test sets using the globally selected target.

    Parameters
    ----------
    df : pandas.DataFrame
        Cleaned dataset that must contain the column named by
        ``target_column_global``.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` from a 70/30 split with a
        fixed random seed for reproducibility.

    Raises
    ------
    ValueError
        If no target column has been selected yet via the UI dropdown.
    """
    global target_column_global
    from sklearn.model_selection import train_test_split

    # Bug fix: the original guard tested the undefined name `target_column`,
    # which raised NameError instead of the intended ValueError.
    if target_column_global is None:
        raise ValueError("Target column not set.")

    X = df.drop(columns=[target_column_global])
    y = df[target_column_global]

    return train_test_split(X, y, test_size=0.3, random_state=42)
| 237 |
|
|
|
|
| 252 |
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 253 |
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 254 |
}
|
| 255 |
+
model = RandomForestClassifier(**params)
|
| 256 |
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
| 257 |
+
if wandb.run is None:
|
| 258 |
+
wandb.init(project="model_optimization", name=f"optuna_trial_{trial.number}", reinit=True)
|
| 259 |
wandb.log({**params, "cv_score": score})
|
| 260 |
return score
|
| 261 |
|
|
|
|
| 297 |
import warnings
|
| 298 |
warnings.filterwarnings("ignore")
|
| 299 |
|
| 300 |
+
global target_column_global
|
| 301 |
+
target = target_column_global
|
| 302 |
X = df_global.drop(target, axis=1)
|
| 303 |
y = df_global[target]
|
| 304 |
|
|
|
|
| 369 |
|
| 370 |
return shap_path, lime_path
|
| 371 |
|
| 372 |
+
# Define this BEFORE the Gradio app layout
|
| 373 |
+
|
| 374 |
+
def update_target_choices():
    """Refresh the target-column dropdown from the currently loaded dataset.

    Returns a ``gr.update`` whose choices are the cached DataFrame's columns,
    or an empty choice list when nothing has been uploaded yet.
    """
    # Guard clause: no dataset loaded yet -> empty dropdown.
    if df_global is None:
        return gr.update(choices=[])
    return gr.update(choices=df_global.columns.tolist())
|
| 380 |
+
|
| 381 |
+
|
| 382 |
with gr.Blocks() as demo:
|
| 383 |
gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
|
| 384 |
|
|
|
|
| 386 |
with gr.Column():
|
| 387 |
file_input = gr.File(label="Upload CSV or Excel", type="filepath")
|
| 388 |
df_output = gr.DataFrame(label="Cleaned Data Preview")
|
| 389 |
+
target_dropdown = gr.Dropdown(label="Select Target Column", choices=[], interactive=True)
|
| 390 |
+
target_status = gr.Textbox(label="Target Column Status", interactive=False)
|
| 391 |
+
|
| 392 |
+
file_input.change(fn=upload_file, inputs=file_input, outputs=[df_output, target_dropdown])
|
| 393 |
+
#file_input.change(fn=update_target_choices, inputs=[], outputs=target_dropdown)
|
| 394 |
+
target_dropdown.change(fn=set_target_column, inputs=target_dropdown, outputs=target_status)
|
| 395 |
|
| 396 |
with gr.Column():
|
| 397 |
insights_output = gr.HTML(label="Insights from SmolAgent")
|
|
|
|
| 408 |
shap_img = gr.Image(label="SHAP Summary Plot")
|
| 409 |
lime_img = gr.Image(label="LIME Explanation")
|
| 410 |
|
| 411 |
+
with gr.Row():
|
| 412 |
+
compare_btn = gr.Button("Compare Models (A/B Testing)")
|
| 413 |
+
compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
|
| 414 |
+
compare_img = gr.Image(label="Model Accuracy Plot")
|
| 415 |
+
|
| 416 |
agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
|
| 417 |
train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
|
| 418 |
explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
|
| 419 |
+
compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])
|
| 420 |
|
| 421 |
demo.launch(debug=True)
|