from __future__ import annotations from pathlib import Path from typing import Dict, Optional, Sequence import joblib import numpy as np import pandas as pd from tqdm import tqdm from .constants import TARGET_NAMES from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray: mask = np.ones(prediction_matrix.shape[1], dtype=bool) mask[target_idx] = False return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1) def train_stage_two_models( train_features: np.ndarray, val_features: Optional[np.ndarray], train_df: pd.DataFrame, val_df: Optional[pd.DataFrame], config: Dict, checkpoint_dir: Path, stage1_train_preds: np.ndarray, stage1_val_preds: Optional[np.ndarray], target_names: Sequence[str] = TARGET_NAMES, ) -> Dict: training_cfg = config.get("training", {}) base_params = training_cfg.get("lightgbm_params", {}) n_trials = training_cfg.get("optuna_trials", 40) n_estimators_choices = resolve_n_estimators(training_cfg) early_stopping = training_cfg.get("early_stopping_rounds", 100) seed = config.get("seed", 42) stage_dir = checkpoint_dir / "stage2" stage_dir.mkdir(parents=True, exist_ok=True) metrics: Dict[str, Dict] = {} task_list = list(target_names) with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar: for task_idx, task_name in enumerate(progress_bar): progress_bar.set_postfix(task=task_name) mask = train_df[task_name].notna().values if mask.sum() == 0: metrics[task_name] = {"status": "skipped", "reason": "no labels"} continue augmented_train_matrix = _build_augmented_matrix( train_features[mask], stage1_train_preds[mask], task_idx, ) y_train = train_df.loc[mask, task_name].astype(float).values if ( val_features is None or val_df is None or stage1_val_preds is None or val_df[task_name].notna().sum() < 2 ): metrics[task_name] = {"status": "skipped", "reason": "missing validation data"} continue val_mask = val_df[task_name].notna().values augmented_val_matrix = _build_augmented_matrix( val_features[val_mask], stage1_val_preds[val_mask], task_idx, ) y_val = val_df.loc[val_mask, task_name].astype(float).values if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2: metrics[task_name] = {"status": "skipped", "reason": "single-class labels"} continue task_result = train_lightgbm_task( augmented_train_matrix, y_train, augmented_val_matrix, y_val, base_params=base_params, n_estimators_choices=n_estimators_choices, early_stopping_rounds=early_stopping, n_trials=n_trials, seed=seed, ) if task_result is None: metrics[task_name] = {"status": "skipped", "reason": "training failed"} continue model_path = stage_dir / f"{task_name}.pkl" joblib.dump(task_result.model, model_path) metrics[task_name] = { "val_auc": task_result.val_auc, "best_iteration": int(task_result.best_iteration), } save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json") return {"metrics": metrics}