pavanmutha commited on
Commit
d1a62b9
·
verified ·
1 Parent(s): cd69066

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -29
app.py CHANGED
@@ -244,35 +244,6 @@ wandb_run.finish()
244
 
245
 
246
 
247
-
248
- def objective(trial):
249
- params = {
250
- "n_estimators": trial.suggest_int("n_estimators", 50, 200),
251
- "max_depth": trial.suggest_int("max_depth", 3, 10),
252
- }
253
- model = RandomForestClassifier(**params)
254
- score = cross_val_score(model, X_train, y_train, cv=3).mean()
255
- wandb.log(params | {"cv_score": score})
256
- return score
257
-
258
- study = optuna.create_study(direction="maximize")
259
- study.optimize(objective, n_trials=15)
260
-
261
- best_params = study.best_params
262
- model = RandomForestClassifier(**best_params)
263
- model.fit(X_train, y_train)
264
- y_pred = model.predict(X_test)
265
-
266
- metrics = {
267
- "accuracy": accuracy_score(y_test, y_pred),
268
- "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
269
- "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
270
- "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
271
- }
272
- wandb.log(metrics)
273
- wandb_run.finish()
274
-
275
-
276
  def explainability(_):
277
  import warnings
278
  warnings.filterwarnings("ignore")
 
244
 
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def explainability(_):
248
  import warnings
249
  warnings.filterwarnings("ignore")