File size: 3,879 Bytes
94b1553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Dict

import numpy as np
from dotenv import load_dotenv

from src.constants import TARGET_NAMES
from src.features import FingerprintFeaturizer
from src.lightgbm_trainer import train_stage_one_models
from src.preprocess import load_tox21_dataset
from src.seed import set_seed
from src.stage_two import train_stage_two_models


def _default_checkpoint_dir(config: Dict) -> Path:
    checkpoint_cfg = config.get("output", {})
    checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints")
    path = Path(checkpoint_dir)
    path.mkdir(parents=True, exist_ok=True)
    return path


def _save_stage1_predictions(cache_dir, train_full, val_full) -> None:
    """Persist stage-1 prediction arrays under *cache_dir* as .npz archives.

    The validation archive is written only when validation predictions exist.
    """
    names = np.array(TARGET_NAMES, dtype=object)
    np.savez(
        cache_dir / "stage1_train_predictions.npz",
        full=train_full,
        target_names=names,
    )
    if val_full is not None:
        np.savez(
            cache_dir / "stage1_validation_predictions.npz",
            full=val_full,
            target_names=names,
        )


def _build_manifest(config, dataset_cfg, checkpoint_dir, multitask_cfg, stage2_ran) -> Dict:
    """Assemble the training manifest describing produced artifacts.

    Stage-2 paths are recorded only when stage 2 actually ran (*stage2_ran*),
    not merely when it was enabled in the config.
    """
    return {
        "feature_config": config.get("features", {}),
        "target_names": TARGET_NAMES,
        "dataset": dataset_cfg,
        "stage1": {
            "model_dir": str(checkpoint_dir / "stage1"),
            "metrics": str(checkpoint_dir / "metrics_stage1.json"),
        },
        "stage2": {
            "enabled": bool(multitask_cfg.get("enabled", False)),
            "model_dir": str(checkpoint_dir / "stage2") if stage2_ran else None,
            "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_ran else None,
        },
        "multitask": multitask_cfg,
        "seed": config.get("seed", 42),
    }


def train(config: Dict):
    """Run the two-stage LightGBM training pipeline.

    Loads the Tox21 dataset, featurizes train/validation splits, trains the
    stage-1 baseline models, caches their predictions, optionally trains
    stage-2 multitask-augmented models, and writes a JSON manifest describing
    all produced artifacts.

    Args:
        config: Parsed configuration dict (dataset, features, output,
            multitask, seed sections are all read with defaults).

    Raises:
        ValueError: If the dataset does not provide both 'train' and
            'validation' splits.
    """
    load_dotenv()
    set_seed(config.get("seed", 42))
    # Token for dataset access; may be None if not configured in the env.
    token = os.getenv("TOKEN")

    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("name", "ml-jku/tox21")
    splits = load_tox21_dataset(token, dataset_name)

    if "train" not in splits or "validation" not in splits:
        raise ValueError("Dataset must provide 'train' and 'validation' splits.")

    featurizer = FingerprintFeaturizer(config.get("features", {}))
    train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train")
    val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation")

    checkpoint_dir = _default_checkpoint_dir(config)
    cache_dir = checkpoint_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("==== Stage 1: Training baseline LightGBM models ====")
    stage1_artifacts = train_stage_one_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        target_names=TARGET_NAMES,
    )

    stage1_train_full = stage1_artifacts["train_full"]
    stage1_val_full = stage1_artifacts["val_full"]
    _save_stage1_predictions(cache_dir, stage1_train_full, stage1_val_full)

    multitask_cfg = config.get("multitask", {"enabled": False})
    # Track whether stage 2 actually executed; previously the manifest keyed
    # its paths off "metrics is not None", which mis-reported a trained stage 2
    # whose artifacts carried no metrics.
    stage2_ran = False
    if multitask_cfg.get("enabled", False):
        print("==== Stage 2: Training multitask-augmented LightGBM models ====")
        train_stage_two_models(
            train_features,
            val_features,
            train_df,
            val_df,
            config,
            checkpoint_dir,
            stage1_train_full,
            stage1_val_full,
            target_names=TARGET_NAMES,
        )
        stage2_ran = True

    manifest = _build_manifest(config, dataset_cfg, checkpoint_dir, multitask_cfg, stage2_ran)

    manifest_path = checkpoint_dir / "training_manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        # ensure_ascii=False keeps any non-ASCII dataset/target names readable.
        json.dump(manifest, f, indent=2, ensure_ascii=False)

    print("Training complete.")


if __name__ == "__main__":
    # Script entry point: load the JSON config and launch training.
    config_path = Path("./config/config.json")
    with config_path.open("r", encoding="utf-8") as handle:
        train(json.load(handle))