pavanmutha commited on
Commit
82a455e
·
verified ·
1 Parent(s): 8e3b920

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -29
app.py CHANGED
@@ -192,35 +192,6 @@ def train_model(_):
192
 
193
 
194
 
195
- def objective(trial):
196
- params = {
197
- "n_estimators": trial.suggest_int("n_estimators", 50, 200),
198
- "max_depth": trial.suggest_int("max_depth", 3, 10),
199
- }
200
- model = RandomForestClassifier(**params)
201
- score = cross_val_score(model, X_train, y_train, cv=3).mean() # Now X_train and y_train are defined
202
- wandb.log(params | {"cv_score": score})
203
- return score
204
-
205
- study = optuna.create_study(direction="maximize")
206
- study.optimize(objective, n_trials=15)
207
-
208
- best_params = study.best_params
209
- model = RandomForestClassifier(**best_params)
210
- model.fit(X_train, y_train)
211
- y_pred = model.predict(X_test)
212
-
213
- metrics = {
214
- "accuracy": accuracy_score(y_test, y_pred),
215
- "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
216
- "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
217
- "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
218
- }
219
- wandb.log(metrics)
220
- wandb_run.finish()
221
-
222
-
223
-
224
  def explainability(_):
225
  import warnings
226
  warnings.filterwarnings("ignore")
 
192
 
193
 
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def explainability(_):
196
  import warnings
197
  warnings.filterwarnings("ignore")