Yair - added error handling for NA y_train labels
Browse files- model_trainer.py +11 -1
model_trainer.py
CHANGED
|
@@ -23,7 +23,17 @@ def train_models(X_train, y_train, categorical_columns):
|
|
| 23 |
models["XGBoost"] = xgb
|
| 24 |
print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
|
| 25 |
else:
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Train RandomForest
|
| 29 |
start_time = time.time()
|
|
|
|
| 23 |
models["XGBoost"] = xgb
|
| 24 |
print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
|
| 25 |
else:
|
| 26 |
+
x_train_xgboost = X_train[~y_train.isna()]
|
| 27 |
+
y_train_xgboost = y_train.dropna()
|
| 28 |
+
if set(y_train_xgboost.unique()) <= {0, 1}:
|
| 29 |
+
start_time = time.time()
|
| 30 |
+
xgb = XGBClassifier(**XGB_PARAMS)
|
| 31 |
+
xgb.fit(x_train_xgboost, y_train_xgboost)
|
| 32 |
+
models["XGBoost"] = xgb
|
| 33 |
+
print(f"✅ XGBoost trained in {time.time() - start_time:.2f} sec")
|
| 34 |
+
else:
|
| 35 |
+
models["XGBoost"] = None
|
| 36 |
+
print("⚠ XGBoost training skipped due to invalid labels!")
|
| 37 |
|
| 38 |
# Train RandomForest
|
| 39 |
start_time = time.time()
|