Update app.py
Browse files
app.py
CHANGED
|
@@ -241,8 +241,10 @@ def train_model(_):
|
|
| 241 |
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 242 |
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 243 |
}
|
| 244 |
-
model = RandomForestClassifier()
|
| 245 |
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
|
|
|
|
|
|
| 246 |
wandb.log({**params, "cv_score": score})
|
| 247 |
return score
|
| 248 |
|
|
|
|
| 241 |
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 242 |
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 243 |
}
|
| 244 |
+
model = RandomForestClassifier(**params)
|
| 245 |
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
| 246 |
+
if wandb.run is None:
|
| 247 |
+
wandb.init(project="model_optimization", name=f"optuna_trial_{trial.number}", reinit=True)
|
| 248 |
wandb.log({**params, "cv_score": score})
|
| 249 |
return score
|
| 250 |
|