# MultiTaskTox / src/stage_two.py
# Author: Maximilian Schuh
# History: "Added new files and updated learning" (commit 759324e)
from __future__ import annotations
from pathlib import Path
from typing import Dict, Optional, Sequence
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from .constants import TARGET_NAMES
from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task
def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
mask = np.ones(prediction_matrix.shape[1], dtype=bool)
mask[target_idx] = False
return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1)
def train_stage_two_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_preds: np.ndarray,
    stage1_val_preds: Optional[np.ndarray],
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one stage-2 LightGBM model per task on features augmented with stage-1 predictions.

    For each task, the feature matrix is the base features concatenated with the
    stage-1 predictions of every *other* task (stacked generalization). Models are
    saved to ``checkpoint_dir/stage2/<task>.pkl`` and per-task metrics to
    ``checkpoint_dir/metrics_stage2.json``.

    Parameters
    ----------
    train_features / val_features : base feature matrices; rows are assumed to be
        positionally aligned with ``train_df`` / ``val_df`` — TODO confirm at call site.
    train_df / val_df : DataFrames holding per-task label columns (NaN = unlabeled).
    config : dict with optional ``training`` (LightGBM/Optuna settings) and ``seed`` keys.
    checkpoint_dir : root directory for model and metric artifacts.
    stage1_train_preds / stage1_val_preds : stage-1 prediction matrices, one column
        per task in ``target_names`` order.
    target_names : task (label column) names; defaults to ``TARGET_NAMES``.

    Returns
    -------
    Dict with a single ``"metrics"`` key mapping each task name to either its
    validation results or a ``{"status": "skipped", "reason": ...}`` record.
    """
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    n_estimators_choices = resolve_n_estimators(training_cfg)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)
    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)
    metrics: Dict[str, Dict] = {}
    task_list = list(target_names)
    with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar:
        for task_idx, task_name in enumerate(progress_bar):
            progress_bar.set_postfix(task=task_name)
            mask = train_df[task_name].notna().values
            if mask.sum() == 0:
                metrics[task_name] = {"status": "skipped", "reason": "no labels"}
                continue
            # Check validation availability *before* building the (potentially large)
            # augmented training matrix, so skipped tasks pay no concatenation cost.
            if (
                val_features is None
                or val_df is None
                or stage1_val_preds is None
                or val_df[task_name].notna().sum() < 2
            ):
                metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
                continue
            augmented_train_matrix = _build_augmented_matrix(
                train_features[mask],
                stage1_train_preds[mask],
                task_idx,
            )
            y_train = train_df.loc[mask, task_name].astype(float).values
            val_mask = val_df[task_name].notna().values
            augmented_val_matrix = _build_augmented_matrix(
                val_features[val_mask],
                stage1_val_preds[val_mask],
                task_idx,
            )
            y_val = val_df.loc[val_mask, task_name].astype(float).values
            # AUC is undefined for a single class on either split.
            if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
                continue
            task_result = train_lightgbm_task(
                augmented_train_matrix,
                y_train,
                augmented_val_matrix,
                y_val,
                base_params=base_params,
                n_estimators_choices=n_estimators_choices,
                early_stopping_rounds=early_stopping,
                n_trials=n_trials,
                seed=seed,
            )
            if task_result is None:
                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
                continue
            model_path = stage_dir / f"{task_name}.pkl"
            joblib.dump(task_result.model, model_path)
            metrics[task_name] = {
                "val_auc": task_result.val_auc,
                "best_iteration": int(task_result.best_iteration),
            }
    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}