pavanmutha committed on
Commit
e0eece7
·
verified ·
1 Parent(s): c874a5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -420,7 +420,6 @@ def prepare_data(df):
420
 
421
  return train_test_split(X, y, test_size=0.3, random_state=42)
422
 
423
-
424
  def train_model(_):
425
  try:
426
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
@@ -439,36 +438,30 @@ def train_model(_):
439
  }
440
  model = RandomForestClassifier(**params)
441
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
442
- if wandb.run is None:
443
- wandb.init(project="model_optimization", name=f"optuna_trial_{trial.number}", reinit=True)
444
  wandb.log({**params, "cv_score": score})
445
- return score
446
 
447
  study = optuna.create_study(direction="maximize")
448
  study.optimize(objective, n_trials=15)
449
 
450
  best_params = study.best_params
451
- model = RandomForestClassifier()
452
  model.fit(X_train, y_train)
453
  y_pred = model.predict(X_test)
454
 
455
-
456
  metrics = {
457
  "accuracy": accuracy_score(y_test, y_pred),
458
  "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
459
  "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
460
  "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
461
  }
 
462
  wandb.log(metrics)
463
  wandb_run.finish()
464
 
465
  # Top 7 trials
466
  top_trials = sorted(study.trials, key=lambda x: x.value, reverse=True)[:7]
467
- trial_rows = []
468
- for t in top_trials:
469
- row = t.params.copy()
470
- row["score"] = t.value
471
- trial_rows.append(row)
472
  trials_df = pd.DataFrame(trial_rows)
473
 
474
  return metrics, trials_df
@@ -478,6 +471,7 @@ def train_model(_):
478
  return {}, pd.DataFrame()
479
 
480
 
 
481
  def explainability(_):
482
  import warnings
483
  warnings.filterwarnings("ignore")
 
420
 
421
  return train_test_split(X, y, test_size=0.3, random_state=42)
422
 
 
423
  def train_model(_):
424
  try:
425
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
 
438
  }
439
  model = RandomForestClassifier(**params)
440
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
 
 
441
  wandb.log({**params, "cv_score": score})
442
+ return score # ✅ Must be returned here
443
 
444
  study = optuna.create_study(direction="maximize")
445
  study.optimize(objective, n_trials=15)
446
 
447
  best_params = study.best_params
448
+ model = RandomForestClassifier(**best_params)
449
  model.fit(X_train, y_train)
450
  y_pred = model.predict(X_test)
451
 
 
452
  metrics = {
453
  "accuracy": accuracy_score(y_test, y_pred),
454
  "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
455
  "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
456
  "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
457
  }
458
+
459
  wandb.log(metrics)
460
  wandb_run.finish()
461
 
462
  # Top 7 trials
463
  top_trials = sorted(study.trials, key=lambda x: x.value, reverse=True)[:7]
464
+ trial_rows = [dict(**t.params, score=t.value) for t in top_trials]
 
 
 
 
465
  trials_df = pd.DataFrame(trial_rows)
466
 
467
  return metrics, trials_df
 
471
  return {}, pd.DataFrame()
472
 
473
 
474
+
475
  def explainability(_):
476
  import warnings
477
  warnings.filterwarnings("ignore")