pavanmutha commited on
Commit
0f6d44a
·
verified ·
1 Parent(s): 82a455e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py CHANGED
@@ -190,6 +190,51 @@ def train_model(_):
190
  wandb_run = wandb.init(project="huggingface-data-analysis", name=f"Optuna_Run_{run_counter}", reinit=True)
191
  run_counter += 1
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  def explainability(_):
 
190
  wandb_run = wandb.init(project="huggingface-data-analysis", name=f"Optuna_Run_{run_counter}", reinit=True)
191
  run_counter += 1
192
 
193
+
194
+ X_train, X_test, y_train, y_test = prepare_data()
195
+
196
+ def objective(trial):
197
+ params = {
198
+ "n_estimators": trial.suggest_int("n_estimators", 50, 200),
199
+ "max_depth": trial.suggest_int("max_depth", 3, 10),
200
+ }
201
+ model = RandomForestClassifier(**params)
202
+ score = cross_val_score(model, X_train, y_train, cv=3).mean()
203
+ wandb.log({**params, "cv_score": score})
204
+ return score
205
+
206
+ study = optuna.create_study(direction="maximize")
207
+ study.optimize(objective, n_trials=15)
208
+
209
+ best_params = study.best_params
210
+ model = RandomForestClassifier(**best_params)
211
+ model.fit(X_train, y_train)
212
+ y_pred = model.predict(X_test)
213
+
214
+ metrics = {
215
+ "accuracy": accuracy_score(y_test, y_pred),
216
+ "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
217
+ "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
218
+ "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
219
+ }
220
+ wandb.log(metrics)
221
+ wandb_run.finish()
222
+
223
+ # Top 7 trials
224
+ top_trials = sorted(study.trials, key=lambda x: x.value, reverse=True)[:7]
225
+ trial_rows = []
226
+ for t in top_trials:
227
+ row = t.params.copy()
228
+ row["score"] = t.value
229
+ trial_rows.append(row)
230
+ trials_df = pd.DataFrame(trial_rows)
231
+
232
+ return metrics, trials_df
233
+
234
+ except Exception as e:
235
+ print(f"Training Error: {e}")
236
+ return {}, pd.DataFrame()
237
+
238
 
239
 
240
  def explainability(_):