Update app.py
Browse files
app.py
CHANGED
|
@@ -244,35 +244,6 @@ wandb_run.finish()
|
|
| 244 |
|
| 245 |
|
| 246 |
|
| 247 |
-
|
| 248 |
-
def objective(trial):
|
| 249 |
-
params = {
|
| 250 |
-
"n_estimators": trial.suggest_int("n_estimators", 50, 200),
|
| 251 |
-
"max_depth": trial.suggest_int("max_depth", 3, 10),
|
| 252 |
-
}
|
| 253 |
-
model = RandomForestClassifier(**params)
|
| 254 |
-
score = cross_val_score(model, X_train, y_train, cv=3).mean()
|
| 255 |
-
wandb.log(params | {"cv_score": score})
|
| 256 |
-
return score
|
| 257 |
-
|
| 258 |
-
study = optuna.create_study(direction="maximize")
|
| 259 |
-
study.optimize(objective, n_trials=15)
|
| 260 |
-
|
| 261 |
-
best_params = study.best_params
|
| 262 |
-
model = RandomForestClassifier(**best_params)
|
| 263 |
-
model.fit(X_train, y_train)
|
| 264 |
-
y_pred = model.predict(X_test)
|
| 265 |
-
|
| 266 |
-
metrics = {
|
| 267 |
-
"accuracy": accuracy_score(y_test, y_pred),
|
| 268 |
-
"precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 269 |
-
"recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 270 |
-
"f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
|
| 271 |
-
}
|
| 272 |
-
wandb.log(metrics)
|
| 273 |
-
wandb_run.finish()
|
| 274 |
-
|
| 275 |
-
|
| 276 |
def explainability(_):
|
| 277 |
import warnings
|
| 278 |
warnings.filterwarnings("ignore")
|
|
|
|
| 244 |
|
| 245 |
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
def explainability(_):
|
| 248 |
import warnings
|
| 249 |
warnings.filterwarnings("ignore")
|