KaiquanMah commited on
Commit
3b7934d
·
verified ·
1 Parent(s): f5755a2

Yair - added error handling for NA y_train labels

Browse files
Files changed (1) hide show
  1. model_trainer.py +11 -1
model_trainer.py CHANGED
@@ -23,7 +23,17 @@ def train_models(X_train, y_train, categorical_columns):
23
  models["XGBoost"] = xgb
24
  print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
25
  else:
26
- print("⚠ XGBoost training skipped due to invalid labels!")
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Train RandomForest
29
  start_time = time.time()
 
23
  models["XGBoost"] = xgb
24
  print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
25
  else:
26
+ x_train_xgboost = X_train[~y_train.isna()]
27
+ y_train_xgboost = y_train.dropna()
28
+ if set(y_train_xgboost.unique()) <= {0, 1}:
29
+ start_time = time.time()
30
+ xgb = XGBClassifier(**XGB_PARAMS)
31
+ xgb.fit(x_train_xgboost, y_train_xgboost)
32
+ models["XGBoost"] = xgb
33
+ print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
34
+ else:
35
+ models["XGBoost"] = None
36
+ print("⚠ XGBoost training skipped due to invalid labels!")
37
 
38
  # Train RandomForest
39
  start_time = time.time()