File size: 3,828 Bytes
94b1553
 
 
 
 
 
 
 
759324e
94b1553
 
759324e
94b1553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759324e
94b1553
 
 
 
 
 
 
759324e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94b1553
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

from pathlib import Path
from typing import Dict, Optional, Sequence

import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm

from .constants import TARGET_NAMES
from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task


def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
    mask = np.ones(prediction_matrix.shape[1], dtype=bool)
    mask[target_idx] = False
    return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1)


def train_stage_two_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_preds: np.ndarray,
    stage1_val_preds: Optional[np.ndarray],
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one stage-2 LightGBM model per task on base features plus stage-1 predictions.

    For the task at position ``i`` in ``target_names``, the model's inputs are
    the base features plus every stage-1 prediction column except column ``i``.
    Rows whose label for that task is missing (NaN) are excluded from both the
    training and the validation set.

    Args:
        train_features: Base training features, row-aligned with ``train_df``.
        val_features: Base validation features, row-aligned with ``val_df``.
            If this, ``val_df`` or ``stage1_val_preds`` is ``None``, no task is
            trained.
        train_df: Frame with one label column per task name.
        val_df: Validation frame with one label column per task name.
        config: Reads ``config["seed"]`` (default 42) and, from
            ``config["training"]``: ``lightgbm_params`` (default ``{}``),
            ``optuna_trials`` (default 40) and ``early_stopping_rounds``
            (default 100). The training sub-dict is also passed to
            ``resolve_n_estimators``.
        checkpoint_dir: Trained models are written to
            ``checkpoint_dir / "stage2" / f"{task}.pkl"`` via ``joblib.dump``.
            Metrics are written to ``checkpoint_dir / "metrics_stage2.json"``.
        stage1_train_preds: Stage-1 prediction matrix for the training rows.
            NOTE(review): column ``i`` is assumed to correspond to
            ``target_names[i]``; confirm against the stage-1 trainer.
        stage1_val_preds: Stage-1 prediction matrix for the validation rows,
            with the same column assumption.
        target_names: Task (label column) names, in stage-1 column order.

    Returns:
        ``{"metrics": metrics}``, where ``metrics[task]`` is either
        ``{"val_auc": ..., "best_iteration": int}`` for a trained task, or
        ``{"status": "skipped", "reason": <str>}``.
    """
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    n_estimators_choices = resolve_n_estimators(training_cfg)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)

    # Presence of validation inputs does not depend on the task, so check it once.
    has_val_inputs = (
        val_features is not None
        and val_df is not None
        and stage1_val_preds is not None
    )

    metrics: Dict[str, Dict] = {}
    task_list = list(target_names)

    with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar:
        for task_idx, task_name in enumerate(progress_bar):
            progress_bar.set_postfix(task=task_name)

            # All skip checks run before any matrix is built, so skipped tasks
            # never pay for the (copying) feature concatenation.
            train_mask = train_df[task_name].notna().values
            if not train_mask.any():
                metrics[task_name] = {"status": "skipped", "reason": "no labels"}
                continue

            val_mask = val_df[task_name].notna().values if has_val_inputs else None
            if val_mask is None or val_mask.sum() < 2:
                metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
                continue

            y_train = train_df.loc[train_mask, task_name].astype(float).values
            y_val = val_df.loc[val_mask, task_name].astype(float).values

            # With a single label value there is nothing to discriminate, and
            # ROC AUC on the validation side is undefined.
            if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
                continue

            # Drop this task's own stage-1 column, so its prediction is not a feature.
            augmented_train_matrix = _build_augmented_matrix(
                train_features[train_mask],
                stage1_train_preds[train_mask],
                task_idx,
            )
            augmented_val_matrix = _build_augmented_matrix(
                val_features[val_mask],
                stage1_val_preds[val_mask],
                task_idx,
            )

            task_result = train_lightgbm_task(
                augmented_train_matrix,
                y_train,
                augmented_val_matrix,
                y_val,
                base_params=base_params,
                n_estimators_choices=n_estimators_choices,
                early_stopping_rounds=early_stopping,
                n_trials=n_trials,
                seed=seed,
            )

            # train_lightgbm_task signals failure by returning None.
            if task_result is None:
                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
                continue

            model_path = stage_dir / f"{task_name}.pkl"
            joblib.dump(task_result.model, model_path)

            metrics[task_name] = {
                "val_auc": task_result.val_auc,
                "best_iteration": int(task_result.best_iteration),
            }

    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}