pavanmutha commited on
Commit
42583bc
·
verified ·
1 Parent(s): e0eece7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -8
app.py CHANGED
@@ -344,13 +344,16 @@ def format_insights(insights, visuals):
344
  ])
345
 
346
 
347
-
348
-
349
-
 
 
 
350
 
351
  def compare_models():
352
  import seaborn as sns
353
- from sklearn.model_selection import cross_val_predict
354
 
355
  if df_global is None:
356
  return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
@@ -360,22 +363,37 @@ def compare_models():
360
  X = df_global.drop(target, axis=1)
361
  y = df_global[target]
362
 
 
363
  if y.dtype == 'object':
364
  y = LabelEncoder().fit_transform(y)
365
 
 
 
 
 
 
366
  models = {
367
  "RandomForest": RandomForestClassifier(),
368
  "LogisticRegression": LogisticRegression(max_iter=1000),
369
- "GradientBoosting": GradientBoostingClassifier()
 
370
  }
371
 
 
 
 
 
 
 
 
 
372
  results = []
373
  for name, model in models.items():
374
  # Cross-validation scores
375
- scores = cross_val_score(model, X, y, cv=5)
376
-
377
  # Cross-validated predictions for metrics
378
- y_pred = cross_val_predict(model, X, y, cv=5)
379
 
380
  metrics = {
381
  "Model": name,
@@ -385,6 +403,7 @@ def compare_models():
385
  "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
386
  "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
387
  }
 
388
  if wandb.run is None:
389
  wandb.init(project="model_comparison", name="compare_models", reinit=True)
390
  wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
@@ -405,6 +424,10 @@ def compare_models():
405
 
406
  return results_df, plot_path
407
 
 
 
 
 
408
 
409
  # 1. prepare_data should come first
410
  def prepare_data(df):
 
344
  ])
345
 
346
 
347
+ from sklearn.model_selection import StratifiedKFold, GridSearchCV
348
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
349
+ from sklearn.linear_model import LogisticRegression
350
+ from sklearn.preprocessing import StandardScaler
351
+ from sklearn.metrics import f1_score, precision_score, recall_score
352
+ import optuna
353
 
354
  def compare_models():
355
  import seaborn as sns
356
+ from sklearn.model_selection import cross_val_predict, cross_val_score
357
 
358
  if df_global is None:
359
  return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
 
363
  X = df_global.drop(target, axis=1)
364
  y = df_global[target]
365
 
366
+ # If the target is categorical, encode it
367
  if y.dtype == 'object':
368
  y = LabelEncoder().fit_transform(y)
369
 
370
+ # Scale features for models like Logistic Regression
371
+ scaler = StandardScaler()
372
+ X_scaled = scaler.fit_transform(X)
373
+
374
+ # Define models
375
  models = {
376
  "RandomForest": RandomForestClassifier(),
377
  "LogisticRegression": LogisticRegression(max_iter=1000),
378
+ "GradientBoosting": GradientBoostingClassifier(),
379
+ # Consider adding more models like XGBoost
380
  }
381
 
382
+ # Optionally, define an ensemble method
383
+ ensemble_model = VotingClassifier(estimators=[('rf', RandomForestClassifier()),
384
+ ('lr', LogisticRegression(max_iter=1000)),
385
+ ('gb', GradientBoostingClassifier())], voting='hard')
386
+
387
+ # Adding the ensemble model to the list
388
+ models["Voting Classifier"] = ensemble_model
389
+
390
  results = []
391
  for name, model in models.items():
392
  # Cross-validation scores
393
+ scores = cross_val_score(model, X_scaled, y, cv=5)
394
+
395
  # Cross-validated predictions for metrics
396
+ y_pred = cross_val_predict(model, X_scaled, y, cv=5)
397
 
398
  metrics = {
399
  "Model": name,
 
403
  "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
404
  "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
405
  }
406
+ # Log results to WandB
407
  if wandb.run is None:
408
  wandb.init(project="model_comparison", name="compare_models", reinit=True)
409
  wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
 
424
 
425
  return results_df, plot_path
426
 
427
+
428
+
429
+
430
+
431
 
432
  # 1. prepare_data should come first
433
  def prepare_data(df):