pavanmutha commited on
Commit
3c9a5e2
·
verified ·
1 Parent(s): b6ce8f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -18
app.py CHANGED
@@ -31,6 +31,7 @@ login(token=hf_token)
31
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
32
 
33
  df_global = None
 
34
 
35
  def clean_data(df):
36
  df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
@@ -43,12 +44,20 @@ def clean_data(df):
43
  def upload_file(file):
44
  global df_global
45
  if file is None:
46
- return pd.DataFrame({"Error": ["No file uploaded."]})
47
  ext = os.path.splitext(file.name)[-1]
48
  df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
49
  df = clean_data(df)
50
  df_global = df
51
- return df.head()
 
 
 
 
 
 
 
 
52
 
53
  def format_analysis_report(raw_output, visuals):
54
  try:
@@ -155,10 +164,14 @@ def analyze_data(csv_file, additional_notes=""):
155
  return format_analysis_report(analysis_result, visuals)
156
 
157
  def compare_models():
 
 
 
158
  if df_global is None:
159
- return "Please upload and preprocess a dataset first."
160
-
161
- target = df_global.columns[-1]
 
162
  X = df_global.drop(target, axis=1)
163
  y = df_global[target]
164
 
@@ -168,32 +181,57 @@ def compare_models():
168
  models = {
169
  "RandomForest": RandomForestClassifier(),
170
  "LogisticRegression": LogisticRegression(max_iter=1000),
171
- "SVC": SVC()
172
  }
173
 
174
  results = []
175
  for name, model in models.items():
 
176
  scores = cross_val_score(model, X, y, cv=5)
177
- results.append({
 
 
 
 
178
  "Model": name,
179
  "CV Mean Accuracy": np.mean(scores),
180
- "CV Std Dev": np.std(scores)
181
- })
182
- wandb.log({f"{name}_cv_mean": np.mean(scores), f"{name}_cv_std": np.std(scores)})
 
 
 
 
 
 
183
 
184
  results_df = pd.DataFrame(results)
185
- return results_df
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  # 1. prepare_data should come first
188
- def prepare_data(df, target_column=None):
 
189
  from sklearn.model_selection import train_test_split
190
 
191
  # If no target column is specified, select the first object column or the last column
192
  if target_column is None:
193
- target_column = df.select_dtypes(include=['object']).columns[0] if len(df.select_dtypes(include=['object']).columns) > 0 else df.columns[-1]
194
 
195
- X = df.drop(columns=[target_column])
196
- y = df[target_column]
197
 
198
  return train_test_split(X, y, test_size=0.3, random_state=42)
199
 
@@ -214,8 +252,10 @@ def train_model(_):
214
  "n_estimators": trial.suggest_int("n_estimators", 50, 200),
215
  "max_depth": trial.suggest_int("max_depth", 3, 10),
216
  }
217
- model = RandomForestClassifier()
218
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
 
 
219
  wandb.log({**params, "cv_score": score})
220
  return score
221
 
@@ -257,7 +297,8 @@ def explainability(_):
257
  import warnings
258
  warnings.filterwarnings("ignore")
259
 
260
- target = df_global.columns[-1]
 
261
  X = df_global.drop(target, axis=1)
262
  y = df_global[target]
263
 
@@ -328,6 +369,16 @@ def explainability(_):
328
 
329
  return shap_path, lime_path
330
 
 
 
 
 
 
 
 
 
 
 
331
  with gr.Blocks() as demo:
332
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
333
 
@@ -335,7 +386,12 @@ with gr.Blocks() as demo:
335
  with gr.Column():
336
  file_input = gr.File(label="Upload CSV or Excel", type="filepath")
337
  df_output = gr.DataFrame(label="Cleaned Data Preview")
338
- file_input.change(fn=upload_file, inputs=file_input, outputs=df_output)
 
 
 
 
 
339
 
340
  with gr.Column():
341
  insights_output = gr.HTML(label="Insights from SmolAgent")
@@ -352,8 +408,14 @@ with gr.Blocks() as demo:
352
  shap_img = gr.Image(label="SHAP Summary Plot")
353
  lime_img = gr.Image(label="LIME Explanation")
354
 
 
 
 
 
 
355
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
356
  train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
357
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
 
358
 
359
  demo.launch(debug=True)
 
31
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
32
 
33
  df_global = None
34
+ target_column_global = None
35
 
36
  def clean_data(df):
37
  df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
 
44
  def upload_file(file):
45
  global df_global
46
  if file is None:
47
+ return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])
48
  ext = os.path.splitext(file.name)[-1]
49
  df = pd.read_csv(file.name) if ext == ".csv" else pd.read_excel(file.name)
50
  df = clean_data(df)
51
  df_global = df
52
+ return df.head(), gr.update(choices=df.columns.tolist())
53
+
54
+
55
+
56
+ def set_target_column(col_name):
57
+ global target_column_global
58
+ target_column_global = col_name
59
+ return f"✅ Target column set to: {col_name}"
60
+
61
 
62
  def format_analysis_report(raw_output, visuals):
63
  try:
 
164
  return format_analysis_report(analysis_result, visuals)
165
 
166
  def compare_models():
167
+ import seaborn as sns
168
+ from sklearn.model_selection import cross_val_predict
169
+
170
  if df_global is None:
171
+ return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
172
+
173
+ global target_column_global
174
+ target = target_column_global
175
  X = df_global.drop(target, axis=1)
176
  y = df_global[target]
177
 
 
181
  models = {
182
  "RandomForest": RandomForestClassifier(),
183
  "LogisticRegression": LogisticRegression(max_iter=1000),
184
+ "GradientBoosting": GradientBoostingClassifier()
185
  }
186
 
187
  results = []
188
  for name, model in models.items():
189
+ # Cross-validation scores
190
  scores = cross_val_score(model, X, y, cv=5)
191
+
192
+ # Cross-validated predictions for metrics
193
+ y_pred = cross_val_predict(model, X, y, cv=5)
194
+
195
+ metrics = {
196
  "Model": name,
197
  "CV Mean Accuracy": np.mean(scores),
198
+ "CV Std Dev": np.std(scores),
199
+ "F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
200
+ "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
201
+ "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
202
+ }
203
+ if wandb.run is None:
204
+ wandb.init(project="model_comparison", name="compare_models", reinit=True)
205
+ wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
206
+ results.append(metrics)
207
 
208
  results_df = pd.DataFrame(results)
209
+
210
+ # Plotting
211
+ plt.figure(figsize=(8, 5))
212
+ sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d")
213
+ plt.title("Model Comparison (CV Mean Accuracy)")
214
+ plt.ylim(0, 1)
215
+ plt.tight_layout()
216
+
217
+ plot_path = "./model_comparison.png"
218
+ plt.savefig(plot_path)
219
+ plt.close()
220
+
221
+ return results_df, plot_path
222
+
223
 
224
  # 1. prepare_data should come first
225
+ def prepare_data(df):
226
+ global target_column_global
227
  from sklearn.model_selection import train_test_split
228
 
229
  # If no target column is specified, select the first object column or the last column
230
  if target_column is None:
231
+ raise ValueError("Target column not set.")
232
 
233
+ X = df.drop(columns=[target_column_global])
234
+ y = df[target_column_global]
235
 
236
  return train_test_split(X, y, test_size=0.3, random_state=42)
237
 
 
252
  "n_estimators": trial.suggest_int("n_estimators", 50, 200),
253
  "max_depth": trial.suggest_int("max_depth", 3, 10),
254
  }
255
+ model = RandomForestClassifier(**params)
256
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
257
+ if wandb.run is None:
258
+ wandb.init(project="model_optimization", name=f"optuna_trial_{trial.number}", reinit=True)
259
  wandb.log({**params, "cv_score": score})
260
  return score
261
 
 
297
  import warnings
298
  warnings.filterwarnings("ignore")
299
 
300
+ global target_column_global
301
+ target = target_column_global
302
  X = df_global.drop(target, axis=1)
303
  y = df_global[target]
304
 
 
369
 
370
  return shap_path, lime_path
371
 
372
+ # Define this BEFORE the Gradio app layout
373
+
374
+ def update_target_choices():
375
+ global df_global
376
+ if df_global is not None:
377
+ return gr.update(choices=df_global.columns.tolist())
378
+ else:
379
+ return gr.update(choices=[])
380
+
381
+
382
  with gr.Blocks() as demo:
383
  gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
384
 
 
386
  with gr.Column():
387
  file_input = gr.File(label="Upload CSV or Excel", type="filepath")
388
  df_output = gr.DataFrame(label="Cleaned Data Preview")
389
+ target_dropdown = gr.Dropdown(label="Select Target Column", choices=[], interactive=True)
390
+ target_status = gr.Textbox(label="Target Column Status", interactive=False)
391
+
392
+ file_input.change(fn=upload_file, inputs=file_input, outputs=[df_output, target_dropdown])
393
+ #file_input.change(fn=update_target_choices, inputs=[], outputs=target_dropdown)
394
+ target_dropdown.change(fn=set_target_column, inputs=target_dropdown, outputs=target_status)
395
 
396
  with gr.Column():
397
  insights_output = gr.HTML(label="Insights from SmolAgent")
 
408
  shap_img = gr.Image(label="SHAP Summary Plot")
409
  lime_img = gr.Image(label="LIME Explanation")
410
 
411
+ with gr.Row():
412
+ compare_btn = gr.Button("Compare Models (A/B Testing)")
413
+ compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
414
+ compare_img = gr.Image(label="Model Accuracy Plot")
415
+
416
  agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
417
  train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
418
  explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
419
+ compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])
420
 
421
  demo.launch(debug=True)