# MultiTaskTox / src / lightgbm_trainer.py
# Maximilian Schuh
# Added new files and updated learning
# 759324e
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Sequence
import joblib
import lightgbm as lgb
import numpy as np
import optuna
import pandas as pd
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from .constants import TARGET_NAMES
@dataclass
class TaskTrainingOutput:
    """Result of tuning and fitting one LightGBM classifier for a single task."""

    # Classifier refit with the hyperparameters of the best Optuna trial.
    model: lgb.LGBMClassifier
    # ROC AUC of `model` on the validation split it was early-stopped on.
    val_auc: float
    # Iteration passed to predict_proba: the model's `best_iteration_` when
    # available, otherwise the configured `n_estimators`.
    best_iteration: int
    # Full parameter dict the model was constructed with (includes random_state).
    best_params: Dict
def resolve_n_estimators(training_cfg: Dict) -> Sequence[int]:
    """Normalize the n_estimators config entry into a non-empty list of ints.

    The value is read from ``training_cfg["n_estimators"]``, falling back to
    ``training_cfg["boosting_rounds"]`` and finally to ``[50, 500, 1000]``.

    Args:
        training_cfg: Training section of the configuration.

    Returns:
        The positive candidate values, in their configured order, as plain
        Python ints.

    Raises:
        ValueError: If the entry is neither an integer nor a sequence of
            integers, if it contains booleans, or if no positive value remains.
    """
    if "n_estimators" in training_cfg:
        raw_value = training_cfg["n_estimators"]
    elif "boosting_rounds" in training_cfg:
        raw_value = training_cfg["boosting_rounds"]
    else:
        raw_value = [50, 500, 1000]

    type_error_msg = "training.n_estimators must be an int or a sequence of ints"

    # bool is a subclass of int, so `True` would otherwise silently become a
    # single boosting round; treat it as a configuration error instead.
    if isinstance(raw_value, bool):
        raise ValueError(type_error_msg)

    # np.integer covers NumPy scalar ints (e.g. np.int64), which are not
    # instances of the builtin int.
    if isinstance(raw_value, (int, np.integer)):
        candidates = [raw_value]
    elif isinstance(raw_value, Sequence) and not isinstance(raw_value, (str, bytes)):
        candidates = list(raw_value)
    else:
        raise ValueError(type_error_msg)

    if any(isinstance(value, bool) for value in candidates):
        raise ValueError(type_error_msg)

    choices = [int(value) for value in candidates]
    choices = [value for value in choices if value > 0]
    if not choices:
        raise ValueError("training.n_estimators must contain at least one positive value")
    return choices
def _sample_hyperparams(trial: optuna.Trial, base_params: Dict, n_estimators_choices: Sequence[int]) -> Dict:
    """Build one LightGBM parameter dict from ``base_params`` and a trial.

    Values suggested by ``trial`` override same-named entries of
    ``base_params``; a handful of defaults are then filled in only where
    neither source provided a value. ``base_params`` itself is not modified.
    """
    # The suggest_* calls run in the order written here.
    suggested = {
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 16, 256, log=True),
        "max_depth": trial.suggest_int("max_depth", -1, 12),
        "min_child_samples": trial.suggest_int("min_child_samples", 10, 200),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
        "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
        "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
        "n_estimators": trial.suggest_categorical("n_estimators", list(n_estimators_choices)),
    }

    merged = dict(base_params)
    merged.update(suggested)

    # Defaults apply only to keys that are still absent after the merge.
    fallbacks = (
        ("objective", "binary"),
        ("metric", "auc"),
        ("verbosity", -1),
        ("boosting_type", "gbdt"),
        ("n_jobs", -1),
    )
    for key, default in fallbacks:
        if key not in merged:
            merged[key] = default
    return merged
def _fit_and_score(
    params: Dict,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    early_stopping_rounds: int,
) -> tuple[lgb.LGBMClassifier, int, float]:
    """Fit one classifier with early stopping on validation AUC and score it.

    Returns the fitted model, the iteration used for prediction and the
    validation ROC AUC at that iteration.
    """
    model = lgb.LGBMClassifier(**params)
    model.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
        eval_metric="auc",
        callbacks=[
            lgb.early_stopping(
                early_stopping_rounds,
                first_metric_only=True,
                verbose=False,
            )
        ],
    )
    # Fall back to the full ensemble when no usable best iteration was
    # recorded (attribute missing, None or 0), so int() and predict_proba
    # always receive a positive iteration count.
    best_iteration = int(getattr(model, "best_iteration_", None) or params["n_estimators"])
    val_scores = model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
    return model, best_iteration, float(roc_auc_score(y_val, val_scores))


def train_lightgbm_task(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    base_params: Dict,
    n_estimators_choices: Sequence[int],
    early_stopping_rounds: int,
    n_trials: int,
    seed: int,
) -> Optional[TaskTrainingOutput]:
    """Tune and fit a binary LightGBM classifier for one task.

    Runs an Optuna study that maximizes validation ROC AUC, then refits a
    model with the best trial's hyperparameters.

    Args:
        X_train, y_train: Training features and labels.
        X_val, y_val: Validation features and labels; used both for early
            stopping and as the tuning objective.
        base_params: LightGBM parameters applied beneath the sampled ones.
        n_estimators_choices: Candidate values for ``n_estimators``.
        early_stopping_rounds: Patience passed to ``lgb.early_stopping``.
        n_trials: Number of Optuna trials.
        seed: Seed for both the LightGBM models and the Optuna sampler.

    Returns:
        The fitted model with its validation AUC, best iteration and
        parameters, or ``None`` when either split contains fewer than two
        distinct label values (AUC is undefined in that case).
    """
    if len(np.unique(y_train)) < 2 or len(np.unique(y_val)) < 2:
        return None

    def objective(trial: optuna.Trial) -> float:
        params = _sample_hyperparams(trial, base_params, n_estimators_choices)
        params["random_state"] = seed
        _, _, auc = _fit_and_score(
            params, X_train, y_train, X_val, y_val, early_stopping_rounds
        )
        return auc

    # Seed the sampler so the hyperparameter search is reproducible for a
    # given `seed`; an unseeded study draws different trials on every run.
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)

    # Replaying the sampling function on the best trial rebuilds the complete
    # parameter dict (base params + suggested values + defaults).
    best_params = _sample_hyperparams(study.best_trial, base_params, n_estimators_choices)
    best_params["random_state"] = seed
    final_model, best_iteration, val_auc = _fit_and_score(
        best_params, X_train, y_train, X_val, y_val, early_stopping_rounds
    )
    return TaskTrainingOutput(
        model=final_model,
        val_auc=val_auc,
        best_iteration=best_iteration,
        best_params=best_params,
    )
def save_stage_metrics(metrics: Dict, path: Path) -> None:
    """Write ``metrics`` to ``path`` as indented UTF-8 JSON.

    Parent directories are created as needed.

    Args:
        metrics: JSON-serializable mapping to persist.
        path: Destination file.

    Raises:
        TypeError: If ``metrics`` contains a value ``json`` cannot serialize.
            In that case an existing file at ``path`` is left untouched.
    """
    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization error midway would otherwise leave partial, invalid JSON
    # in place of the previous contents.
    payload = json.dumps(metrics, indent=2)
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")
def train_stage_one_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one LightGBM classifier per task and collect their predictions.

    For every task in ``target_names`` the rows with a non-null label are used
    to tune and fit a model via ``train_lightgbm_task``. Each fitted model is
    pickled to ``<checkpoint_dir>/stage1/<task>.pkl``; per-task metrics go to
    ``metrics_stage1.json`` and the chosen parameters to ``stage1_params.json``
    in ``checkpoint_dir``.

    Args:
        train_features: Feature matrix indexed row-for-row like ``train_df``.
        val_features: Feature matrix for ``val_df``, or ``None``.
        train_df: Training frame with one label column per task.
        val_df: Validation frame with the same label columns, or ``None``.
        config: Configuration; reads ``training`` (``lightgbm_params``,
            ``optuna_trials``, ``n_estimators``/``boosting_rounds``,
            ``early_stopping_rounds``) and ``seed``.
        checkpoint_dir: Output directory for models, metrics and parameters.
        target_names: Task (label column) names, in output-column order.

    Returns:
        Dict with ``train_full`` and ``val_full`` prediction matrices of shape
        (rows, tasks) holding the positive-class probability, and ``metrics``.
        Columns of skipped tasks stay at 0.5; ``val_full`` is ``None`` when no
        validation split was supplied.
    """
    stage_dir = checkpoint_dir / "stage1"
    stage_dir.mkdir(parents=True, exist_ok=True)

    # `or {}` also covers sections that are present but set to None, which
    # would otherwise fail on the following `.get` / `dict(...)` calls.
    training_cfg = config.get("training") or {}
    base_params = training_cfg.get("lightgbm_params") or {}
    n_trials = training_cfg.get("optuna_trials", 40)
    n_estimators_choices = resolve_n_estimators(training_cfg)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    task_list = list(target_names)
    n_tasks = len(task_list)
    # Whether a validation split exists does not change per task.
    has_validation = val_df is not None and val_features is not None

    # 0.5 is the placeholder prediction for tasks that end up skipped.
    train_preds = np.full((len(train_df), n_tasks), 0.5, dtype=np.float32)
    val_preds = (
        np.full((len(val_df), n_tasks), 0.5, dtype=np.float32)
        if has_validation
        else None
    )

    metrics: Dict[str, Dict] = {}
    params_dump: Dict[str, Dict] = {}

    with tqdm(task_list, desc="Stage 1", unit="task") as progress_bar:
        for task_idx, task_name in enumerate(progress_bar):
            progress_bar.set_postfix(task=task_name)
            train_mask = train_df[task_name].notna().values

            # Tuning and early stopping both need a validation split.
            if not has_validation:
                metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
                continue

            val_mask = val_df[task_name].notna().values
            if train_mask.sum() < 2 or val_mask.sum() < 2:
                metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
                continue

            # Only rows labeled for this task take part in fitting.
            X_train_task = train_features[train_mask]
            y_train_task = train_df.loc[train_mask, task_name].astype(float).values
            X_val_task = val_features[val_mask]
            y_val_task = val_df.loc[val_mask, task_name].astype(float).values

            if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
                continue

            task_result = train_lightgbm_task(
                X_train_task,
                y_train_task,
                X_val_task,
                y_val_task,
                base_params=base_params,
                n_estimators_choices=n_estimators_choices,
                early_stopping_rounds=early_stopping,
                n_trials=n_trials,
                seed=seed,
            )
            if task_result is None:
                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
                continue

            model = task_result.model
            best_iter = task_result.best_iteration
            joblib.dump(model, stage_dir / f"{task_name}.pkl")

            params_dump[task_name] = {
                **task_result.best_params,
                "best_iteration": best_iter,
                "val_auc": task_result.val_auc,
            }

            # Predict for every row, including rows without a label for this
            # task. NOTE(review): the training-set predictions are in-sample
            # for the rows the model was fitted on.
            train_preds[:, task_idx] = model.predict_proba(
                train_features,
                num_iteration=best_iter,
            )[:, 1].astype(np.float32)
            # val_preds is allocated whenever has_validation is true, which
            # the guard above already established.
            val_preds[:, task_idx] = model.predict_proba(
                val_features,
                num_iteration=best_iter,
            )[:, 1].astype(np.float32)

            metrics[task_name] = {
                "val_auc": task_result.val_auc,
                "n_train_samples": int(train_mask.sum()),
                "n_val_samples": int(val_mask.sum()),
            }

    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
    # Same JSON layout (indent=2, UTF-8) as the metrics file.
    save_stage_metrics(params_dump, checkpoint_dir / "stage1_params.json")

    return {
        "train_full": train_preds,
        "val_full": val_preds,
        "metrics": metrics,
    }