pavanmutha commited on
Commit
940e8f9
·
verified ·
1 Parent(s): 6b7b353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -241,8 +241,10 @@ def train_model(_):
241
  "n_estimators": trial.suggest_int("n_estimators", 50, 200),
242
  "max_depth": trial.suggest_int("max_depth", 3, 10),
243
  }
244
- model = RandomForestClassifier()
245
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
 
 
246
  wandb.log({**params, "cv_score": score})
247
  return score
248
 
 
241
  "n_estimators": trial.suggest_int("n_estimators", 50, 200),
242
  "max_depth": trial.suggest_int("max_depth", 3, 10),
243
  }
244
+ model = RandomForestClassifier(**params)
245
  score = cross_val_score(model, X_train, y_train, cv=3).mean()
246
+ if wandb.run is None:
247
+ wandb.init(project="model_optimization", name=f"optuna_trial_{trial.number}", reinit=True)
248
  wandb.log({**params, "cv_score": score})
249
  return score
250