Spaces:
Sleeping
Sleeping
File size: 3,828 Bytes
94b1553 759324e 94b1553 759324e 94b1553 759324e 94b1553 759324e 94b1553 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from __future__ import annotations
from pathlib import Path
from typing import Dict, Optional, Sequence
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from .constants import TARGET_NAMES
from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task
def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
mask = np.ones(prediction_matrix.shape[1], dtype=bool)
mask[target_idx] = False
return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1)
def train_stage_two_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_preds: np.ndarray,
    stage1_val_preds: Optional[np.ndarray],
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one stage-2 (stacked) LightGBM model per target.

    For each target, the feature matrix is the base features concatenated
    with the stage-1 predictions for every *other* target (classic stacking).
    Models are saved under ``checkpoint_dir / "stage2" / "<task>.pkl"`` and a
    per-task metrics dict is written to ``metrics_stage2.json``.

    Args:
        train_features: Base feature matrix for the training split.
        val_features: Base feature matrix for the validation split, or None.
        train_df: DataFrame holding training labels, one column per target.
        val_df: DataFrame holding validation labels, or None.
        config: Full config dict; reads ``training`` sub-dict and ``seed``.
        checkpoint_dir: Root directory for model files and metrics JSON.
        stage1_train_preds: Stage-1 prediction matrix for the training split,
            one column per target (same order as ``target_names`` — TODO confirm).
        stage1_val_preds: Stage-1 prediction matrix for the validation split,
            or None.
        target_names: Target columns to train, defaults to ``TARGET_NAMES``.

    Returns:
        ``{"metrics": {task_name: {...}}}`` where each task entry is either
        ``{"status": "skipped", "reason": ...}`` or
        ``{"val_auc": float, "best_iteration": int}``.
    """
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    n_estimators_choices = resolve_n_estimators(training_cfg)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)

    metrics: Dict[str, Dict] = {}
    task_list = list(target_names)
    with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar:
        for task_idx, task_name in enumerate(progress_bar):
            progress_bar.set_postfix(task=task_name)

            mask = train_df[task_name].notna().values
            if mask.sum() == 0:
                metrics[task_name] = {"status": "skipped", "reason": "no labels"}
                continue

            # Run all cheap skip checks before building the augmented matrices;
            # the concatenation over the full training split is the expensive
            # part and must not be paid for tasks that end up skipped.
            if (
                val_features is None
                or val_df is None
                or stage1_val_preds is None
                or val_df[task_name].notna().sum() < 2
            ):
                metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
                continue

            y_train = train_df.loc[mask, task_name].astype(float).values
            val_mask = val_df[task_name].notna().values
            y_val = val_df.loc[val_mask, task_name].astype(float).values
            # AUC is undefined when either split contains only one class.
            if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
                continue

            augmented_train_matrix = _build_augmented_matrix(
                train_features[mask],
                stage1_train_preds[mask],
                task_idx,
            )
            augmented_val_matrix = _build_augmented_matrix(
                val_features[val_mask],
                stage1_val_preds[val_mask],
                task_idx,
            )

            task_result = train_lightgbm_task(
                augmented_train_matrix,
                y_train,
                augmented_val_matrix,
                y_val,
                base_params=base_params,
                n_estimators_choices=n_estimators_choices,
                early_stopping_rounds=early_stopping,
                n_trials=n_trials,
                seed=seed,
            )
            if task_result is None:
                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
                continue

            # Persist the fitted model; metrics JSON is rewritten in full below.
            model_path = stage_dir / f"{task_name}.pkl"
            joblib.dump(task_result.model, model_path)
            metrics[task_name] = {
                "val_auc": task_result.val_auc,
                "best_iteration": int(task_result.best_iteration),
            }
    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}
|