Maximilian Schuh committed
Commit 759324e
Parent: e1aa0ed

Added new files and updated learning
README.md CHANGED
@@ -20,30 +20,6 @@ MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [To
 - **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
 - **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
 
-## Installation
-
-```bash
-git clone https://huggingface.co/spaces/ml-jku/tox21_gin_classifier
-cd tox21_gin_classifier
-python -m venv .venv && source .venv/bin/activate
-pip install --upgrade pip
-pip install -r requirements.txt
-```
-
-The requirements include RDKit, LightGBM, Optuna, and the MAP4 fingerprint package so you can switch feature types via the config.
-
-## Training
-
-1. Create a `.env` file (all Hugging Face Spaces support secrets) with your dataset token:
-   ```
-   TOKEN=hf_xxx
-   ```
-2. Adjust `config/config.json` if needed (fingerprint type, Optuna trial count, etc.).
-3. Run:
-   ```bash
-   python train.py
-   ```
-
 ### What `train.py` does
 
 1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
@@ -88,7 +64,7 @@ The function:
   },
   "training": {
     "optuna_trials": 40,
-    "boosting_rounds": 1500,
+    "n_estimators": [50, 500, 1000],
     "early_stopping_rounds": 100,
     "lightgbm_params": {
       "objective": "binary",
@@ -104,6 +80,7 @@ The function:
 - Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
 - Disable multitask behavior by setting `"multitask": {"enabled": false}`.
 - Increase `optuna_trials` for a more exhaustive search if compute allows.
+- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
 
 ## Repository Layout
 
@@ -116,11 +93,3 @@ The function:
 - `src/constants.py`, `src/seed.py` – shared utilities.
 - `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
 - `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
-
-## Tips
-
-- Training relies on the `TOKEN` environment variable to access the Tox21 dataset on Hugging Face. Locally you can omit it if the dataset is public for your account.
-- MAP4 fingerprints are more expensive to compute; enable the cache directory to avoid recomputation across runs.
-- Use the saved metrics files to compare stage-one vs. stage-two AUCs and to trace which configuration produced a set of checkpoints.
-
-Happy modeling! If you extend MultiTaskTox (new fingerprints, alternative learners, etc.), keep the `predict(smiles)` contract intact so your Space remains leaderboard compatible.
checkpoints/cache/stage1_train_predictions.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:019864a99cd005c40bc979da6ec6594398a02fdeaaedb943ea336b7603efab49
+size 565279

checkpoints/cache/stage1_validation_predictions.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eab5701f117dbe4f1bd0755d870e1f5518e9fb99eae505c459af3576534fb843
+size 15007

checkpoints/cache/train_ecfp.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc0c773eb523d256b84fec1cc314f290bcc5b9288073aff576712437450bf5a
+size 48726858

checkpoints/cache/validation_ecfp.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52f576b159c6e09d1ca22f0e260256cd46deacb137dfe77f127c756117671830
+size 1224294
checkpoints/metrics_stage1.json ADDED
@@ -0,0 +1,62 @@
+{
+  "NR-AhR": {
+    "val_auc": 0.8789764868603043,
+    "n_train_samples": 8165,
+    "n_val_samples": 271
+  },
+  "NR-AR": {
+    "val_auc": 0.8547453703703702,
+    "n_train_samples": 9358,
+    "n_val_samples": 291
+  },
+  "NR-AR-LBD": {
+    "val_auc": 0.9717741935483871,
+    "n_train_samples": 8595,
+    "n_val_samples": 252
+  },
+  "NR-Aromatase": {
+    "val_auc": 0.8483560090702948,
+    "n_train_samples": 7222,
+    "n_val_samples": 214
+  },
+  "NR-ER": {
+    "val_auc": 0.7974683544303797,
+    "n_train_samples": 7694,
+    "n_val_samples": 264
+  },
+  "NR-ER-LBD": {
+    "val_auc": 0.8347826086956521,
+    "n_train_samples": 8749,
+    "n_val_samples": 286
+  },
+  "NR-PPAR-gamma": {
+    "val_auc": 0.8077025232403718,
+    "n_train_samples": 8180,
+    "n_val_samples": 266
+  },
+  "SR-ARE": {
+    "val_auc": 0.8194921070693205,
+    "n_train_samples": 7165,
+    "n_val_samples": 233
+  },
+  "SR-ATAD5": {
+    "val_auc": 0.8749593495934959,
+    "n_train_samples": 9087,
+    "n_val_samples": 271
+  },
+  "SR-HSE": {
+    "val_auc": 0.92421875,
+    "n_train_samples": 8147,
+    "n_val_samples": 266
+  },
+  "SR-MMP": {
+    "val_auc": 0.9280613594287226,
+    "n_train_samples": 7317,
+    "n_val_samples": 237
+  },
+  "SR-p53": {
+    "val_auc": 0.8038690476190476,
+    "n_train_samples": 8630,
+    "n_val_samples": 268
+  }
+}
checkpoints/metrics_stage2.json ADDED
@@ -0,0 +1,50 @@
+{
+  "NR-AhR": {
+    "val_auc": 0.8740663900414938,
+    "best_iteration": 5
+  },
+  "NR-AR": {
+    "val_auc": 0.8854166666666666,
+    "best_iteration": 204
+  },
+  "NR-AR-LBD": {
+    "val_auc": 0.9828629032258065,
+    "best_iteration": 23
+  },
+  "NR-Aromatase": {
+    "val_auc": 0.8572845804988662,
+    "best_iteration": 8
+  },
+  "NR-ER": {
+    "val_auc": 0.8123144241287701,
+    "best_iteration": 2
+  },
+  "NR-ER-LBD": {
+    "val_auc": 0.857608695652174,
+    "best_iteration": 4
+  },
+  "NR-PPAR-gamma": {
+    "val_auc": 0.8739707835325365,
+    "best_iteration": 13
+  },
+  "SR-ARE": {
+    "val_auc": 0.857698467169984,
+    "best_iteration": 4
+  },
+  "SR-ATAD5": {
+    "val_auc": 0.8530081300813008,
+    "best_iteration": 145
+  },
+  "SR-HSE": {
+    "val_auc": 0.9201171875,
+    "best_iteration": 2
+  },
+  "SR-MMP": {
+    "val_auc": 0.9312351229833377,
+    "best_iteration": 395
+  },
+  "SR-p53": {
+    "val_auc": 0.8722470238095239,
+    "best_iteration": 3
+  }
+}
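The two metrics files above support the README's tip of comparing stage-one vs. stage-two validation AUCs. A minimal sketch of such a comparison (the `auc_deltas` helper is illustrative, not part of the repo; paths are the ones written by `train.py` in this commit):

```python
import json
from pathlib import Path


def auc_deltas(stage1: dict, stage2: dict) -> dict:
    """Per-task change in validation AUC from stage one to stage two."""
    return {
        task: stage2[task]["val_auc"] - stage1[task]["val_auc"]
        for task in stage1
        if task in stage2 and "val_auc" in stage1.get(task, {})
    }


if __name__ == "__main__":
    stage1 = json.loads(Path("checkpoints/metrics_stage1.json").read_text())
    stage2 = json.loads(Path("checkpoints/metrics_stage2.json").read_text())
    # Sort by delta so tasks hurt by the multitask stage surface first.
    for task, delta in sorted(auc_deltas(stage1, stage2).items(), key=lambda kv: kv[1]):
        print(f"{task}: {delta:+.4f}")
```

On the values shown here, for example, `NR-AR` improves in stage two while `NR-AhR` slips slightly.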
checkpoints/stage1/NR-AR-LBD.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d1b4030d07f948b2b549fc6426587a05c269648fbc40131d6e99241c8972d28
+size 627492

checkpoints/stage1/NR-AR.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:901eb451790a2607cc3e3c6eb6328f947899b557ec9887adc2481aea5a95433e
+size 24772

checkpoints/stage1/NR-AhR.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d2afa34f2143ab1d38247ae82d9c859ea74c37799134c571e30a3ec3e0cfe10
+size 1001188

checkpoints/stage1/NR-Aromatase.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf3aef5acc8411f33bbfe184a45ecb8b292eee93260f1b688a82e2ba26e83ae6
+size 428964

checkpoints/stage1/NR-ER-LBD.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d42aa1195696a8bab89cf0a6ad5914cbfb198d8e946d316dbec3df22621181b
+size 44628

checkpoints/stage1/NR-ER.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8be3a13fc68d2c96b827dab8771026de3944e400150b3b173cc77aa23eec328e
+size 113444

checkpoints/stage1/NR-PPAR-gamma.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6cbc9dce92de09a74df3df581de742044f1ae7d522a5018d8006474d089768d
+size 58420

checkpoints/stage1/SR-ARE.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47fc92884f6ed03d8738f58b809a6c616d13dc7322249008941fbd2506f834bf
+size 392756

checkpoints/stage1/SR-ATAD5.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d00260b6422786faed1efafb99deca01be4021bc7bf4b6da8e807beacac6e14a
+size 375092

checkpoints/stage1/SR-HSE.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7b67c8aba900b7036a4d41e1171d3dabee2423fb1a08a26accf5b3d533065da
+size 237860

checkpoints/stage1/SR-MMP.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:528a7efd2fe9748037da657d88959821f994dfbe35cdbec0a4c415bd3a9ed898
+size 606164

checkpoints/stage1/SR-p53.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2a2dceb70d39f5d72cf68f104be68c763dcf25da33b41bd45abcbcc2e07906
+size 1101540
checkpoints/stage1_params.json ADDED
@@ -0,0 +1,242 @@
+{
+  "NR-AhR": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.04596965586436891,
+    "num_leaves": 26,
+    "max_depth": 11,
+    "min_child_samples": 14,
+    "feature_fraction": 0.9910514912963474,
+    "bagging_fraction": 0.5143055600954589,
+    "bagging_freq": 10,
+    "reg_alpha": 1.514871105815587e-07,
+    "reg_lambda": 1.0837426184192575e-05,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 321,
+    "val_auc": 0.8789764868603043
+  },
+  "NR-AR": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.17083936308641903,
+    "num_leaves": 80,
+    "max_depth": 3,
+    "min_child_samples": 91,
+    "feature_fraction": 0.9278238464266182,
+    "bagging_fraction": 0.6395695127332895,
+    "bagging_freq": 7,
+    "reg_alpha": 5.1826386659948485,
+    "reg_lambda": 0.1806421496400551,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 1,
+    "val_auc": 0.8547453703703702
+  },
+  "NR-AR-LBD": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1992149524787967,
+    "num_leaves": 180,
+    "max_depth": 12,
+    "min_child_samples": 49,
+    "feature_fraction": 0.5650870342995183,
+    "bagging_fraction": 0.5114405817018007,
+    "bagging_freq": 2,
+    "reg_alpha": 0.00012622761264931666,
+    "reg_lambda": 1.2770493065015112e-08,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 210,
+    "val_auc": 0.9717741935483871
+  },
+  "NR-Aromatase": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.07977276592360394,
+    "num_leaves": 225,
+    "max_depth": 10,
+    "min_child_samples": 10,
+    "feature_fraction": 0.9900155440745377,
+    "bagging_fraction": 0.8119652252471269,
+    "bagging_freq": 2,
+    "reg_alpha": 2.594339538659424e-06,
+    "reg_lambda": 5.04193406028159e-07,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 62,
+    "val_auc": 0.8483560090702948
+  },
+  "NR-ER": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.011154410439266932,
+    "num_leaves": 193,
+    "max_depth": 0,
+    "min_child_samples": 13,
+    "feature_fraction": 0.976937312627059,
+    "bagging_fraction": 0.8101945461275061,
+    "bagging_freq": 10,
+    "reg_alpha": 3.712452767192828e-05,
+    "reg_lambda": 2.4101531861104065e-07,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 4,
+    "val_auc": 0.7974683544303797
+  },
+  "NR-ER-LBD": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.002057546388429118,
+    "num_leaves": 46,
+    "max_depth": 7,
+    "min_child_samples": 183,
+    "feature_fraction": 0.9180640687080166,
+    "bagging_fraction": 0.5266816009973788,
+    "bagging_freq": 2,
+    "reg_alpha": 0.4206611876821779,
+    "reg_lambda": 0.024870165101451694,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 18,
+    "val_auc": 0.8347826086956521
+  },
+  "NR-PPAR-gamma": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.05900664732601199,
+    "num_leaves": 235,
+    "max_depth": 10,
+    "min_child_samples": 73,
+    "feature_fraction": 0.9286215997058814,
+    "bagging_fraction": 0.6099533117515117,
+    "bagging_freq": 8,
+    "reg_alpha": 0.0002410679912248092,
+    "reg_lambda": 2.2374585644452814e-05,
+    "n_estimators": 50,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 16,
+    "val_auc": 0.8077025232403718
+  },
+  "SR-ARE": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.11683900433079997,
+    "num_leaves": 25,
+    "max_depth": 6,
+    "min_child_samples": 121,
+    "feature_fraction": 0.9155819290034379,
+    "bagging_fraction": 0.6064835300540737,
+    "bagging_freq": 6,
+    "reg_alpha": 0.031558826962596216,
+    "reg_lambda": 0.49750603290125384,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 313,
+    "val_auc": 0.8194921070693205
+  },
+  "SR-ATAD5": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1398668165886807,
+    "num_leaves": 71,
+    "max_depth": 12,
+    "min_child_samples": 96,
+    "feature_fraction": 0.86899957688569,
+    "bagging_fraction": 0.7807007020967096,
+    "bagging_freq": 10,
+    "reg_alpha": 1.010171177179027e-08,
+    "reg_lambda": 2.3747514557444565e-07,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 126,
+    "val_auc": 0.8749593495934959
+  },
+  "SR-HSE": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.15091207817136804,
+    "num_leaves": 246,
+    "max_depth": 0,
+    "min_child_samples": 19,
+    "feature_fraction": 0.7867053613239711,
+    "bagging_fraction": 0.7013484568271124,
+    "bagging_freq": 9,
+    "reg_alpha": 0.0006360962863973946,
+    "reg_lambda": 6.440534124809522e-05,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 14,
+    "val_auc": 0.92421875
+  },
+  "SR-MMP": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.19884775296113416,
+    "num_leaves": 18,
+    "max_depth": 11,
+    "min_child_samples": 88,
+    "feature_fraction": 0.9550517881195121,
+    "bagging_fraction": 0.7860517484953123,
+    "bagging_freq": 1,
+    "reg_alpha": 0.030296963787402084,
+    "reg_lambda": 0.8044239737357854,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 284,
+    "val_auc": 0.9280613594287226
+  },
+  "SR-p53": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1655358096626077,
+    "num_leaves": 85,
+    "max_depth": 0,
+    "min_child_samples": 87,
+    "feature_fraction": 0.8995086723685061,
+    "bagging_fraction": 0.9621288945710826,
+    "bagging_freq": 6,
+    "reg_alpha": 0.001186616750567751,
+    "reg_lambda": 0.00030749373152708483,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 132,
+    "val_auc": 0.8038690476190476
+  }
+}
checkpoints/stage2/NR-AR-LBD.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b877924ff923463aeb3d9347cc2d64077de0fd1e86c5ebceb1d6e85e7748cbba
+size 71988

checkpoints/stage2/NR-AR.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02e93424e24eb7a4fb704a191ad6d68c05145ab333d434c089e5bb98ff94fb9c
+size 1836148

checkpoints/stage2/NR-AhR.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6681c8042c04542aed22e3c7e1bc756d78cbd90035a1717050202e4841a32c40
+size 61716

checkpoints/stage2/NR-Aromatase.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58d2b37062490cb61caf6b36ed765493fb17bd1805a2f9242e3c8dda058ed05f
+size 41540

checkpoints/stage2/NR-ER-LBD.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be0e41ee6f1e0463073f4eeb983bfb6f7bcedfba2174e18bab55118feec9ef0c
+size 37092

checkpoints/stage2/NR-ER.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aaaec49348aef97d610e5b0dd01ea41728474a5c8fc5cee3e961f0b97a28a1f
+size 28884

checkpoints/stage2/NR-PPAR-gamma.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f95337de1dd3dd937e3e15c74ff80b9cdb4d33b445446c1cd1e4044bcc3cb21
+size 74644

checkpoints/stage2/SR-ARE.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9041bf8b713c4ec2793b67a41d2e302b9aa1f615625a2be53c12efdfdec1232c
+size 52196

checkpoints/stage2/SR-ATAD5.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aafc9da7d6dc67e322264cfdd34788b2aa509e91c9de58a3369ce366b89fbec6
+size 485092

checkpoints/stage2/SR-HSE.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:763d559e337eb24b837d7a0cb9a530baa6b3ac8eccd1195e68ba39fbb078d3cb
+size 30452

checkpoints/stage2/SR-MMP.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8565440bb1a9284032924de804d10993efefd71ce00f92f4014cc11509f51a2
+size 344964

checkpoints/stage2/SR-p53.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f446318ecd2ac93f685664bf68a0866e23a9275d53a32199f6f74557577c2152
+size 30916
checkpoints/training_manifest.json ADDED
@@ -0,0 +1,41 @@
+{
+  "feature_config": {
+    "type": "ecfp",
+    "radius": 2,
+    "n_bits": 1024,
+    "use_counts": false,
+    "map4_dim": 1024,
+    "cache_dir": "./checkpoints/cache"
+  },
+  "target_names": [
+    "NR-AhR",
+    "NR-AR",
+    "NR-AR-LBD",
+    "NR-Aromatase",
+    "NR-ER",
+    "NR-ER-LBD",
+    "NR-PPAR-gamma",
+    "SR-ARE",
+    "SR-ATAD5",
+    "SR-HSE",
+    "SR-MMP",
+    "SR-p53"
+  ],
+  "dataset": {
+    "name": "ml-jku/tox21"
+  },
+  "stage1": {
+    "model_dir": "checkpoints/stage1",
+    "metrics": "checkpoints/metrics_stage1.json"
+  },
+  "stage2": {
+    "enabled": true,
+    "model_dir": "checkpoints/stage2",
+    "metrics": "checkpoints/metrics_stage2.json"
+  },
+  "multitask": {
+    "enabled": true,
+    "prediction_source": "oof"
+  },
+  "seed": 42
+}
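The manifest ties a set of checkpoints back to the configuration that produced them and is consumed at inference time. A small sanity-check sketch (the `missing_metrics` helper is illustrative, not part of the repo): it flags manifest targets that lack an entry in a loaded metrics file, e.g. tasks skipped during training.

```python
def missing_metrics(manifest: dict, metrics: dict) -> list:
    """Targets declared in the manifest with no entry in a stage metrics dict."""
    return [t for t in manifest.get("target_names", []) if t not in metrics]
```

Running it against `checkpoints/training_manifest.json` and `checkpoints/metrics_stage1.json` from this commit should return an empty list, since all twelve Tox21 tasks trained successfully.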
config/config.json CHANGED
@@ -12,8 +12,8 @@
     "cache_dir": "./checkpoints/cache"
   },
   "training": {
-    "optuna_trials": 40,
-    "boosting_rounds": 1500,
+    "optuna_trials": 1000,
+    "n_estimators": [50, 500, 1000],
     "early_stopping_rounds": 100,
     "lightgbm_params": {
       "objective": "binary",
requirements.txt CHANGED
@@ -11,3 +11,4 @@ lightgbm
 optuna
 joblib
 map4
+tqdm
src/lightgbm_trainer.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  import optuna
12
  import pandas as pd
13
  from sklearn.metrics import roc_auc_score
 
14
 
15
  from .constants import TARGET_NAMES
16
 
@@ -23,7 +24,29 @@ class TaskTrainingOutput:
23
  best_params: Dict
24
 
25
 
26
- def _sample_hyperparams(trial: optuna.Trial, base_params: Dict) -> Dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  params = dict(base_params)
28
  params.update(
29
  {
@@ -36,6 +59,7 @@ def _sample_hyperparams(trial: optuna.Trial, base_params: Dict) -> Dict:
36
  "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
37
  "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
38
  "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
 
39
  }
40
  )
41
  params.setdefault("objective", "binary")
@@ -52,7 +76,7 @@ def train_lightgbm_task(
52
  X_val: np.ndarray,
53
  y_val: np.ndarray,
54
  base_params: Dict,
55
- boosting_rounds: int,
56
  early_stopping_rounds: int,
57
  n_trials: int,
58
  seed: int,
@@ -61,8 +85,7 @@ def train_lightgbm_task(
61
  return None
62
 
63
  def objective(trial: optuna.Trial) -> float:
64
- params = _sample_hyperparams(trial, base_params)
65
- params["n_estimators"] = boosting_rounds
66
  params["random_state"] = seed
67
  model = lgb.LGBMClassifier(**params)
68
  model.fit(
@@ -77,17 +100,15 @@ def train_lightgbm_task(
77
  verbose=False,
78
  )
79
  ],
80
- verbose=False,
81
  )
82
- best_iter = getattr(model, "best_iteration_", boosting_rounds)
83
  preds = model.predict_proba(X_val, num_iteration=best_iter)[:, 1]
84
  return float(roc_auc_score(y_val, preds))
85
 
86
  study = optuna.create_study(direction="maximize")
87
  study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
88
 
89
- best_params = _sample_hyperparams(study.best_trial, base_params)
90
- best_params["n_estimators"] = boosting_rounds
91
  best_params["random_state"] = seed
92
 
93
  final_model = lgb.LGBMClassifier(**best_params)
@@ -103,10 +124,9 @@ def train_lightgbm_task(
103
  verbose=False,
104
  )
105
  ],
106
- verbose=False,
107
  )
108
 
109
- best_iteration = getattr(final_model, "best_iteration_", boosting_rounds)
110
  val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
111
  val_auc = roc_auc_score(y_val, val_preds)
112
 
@@ -139,12 +159,13 @@ def train_stage_one_models(
139
  training_cfg = config.get("training", {})
140
  base_params = training_cfg.get("lightgbm_params", {})
141
  n_trials = training_cfg.get("optuna_trials", 40)
142
- boosting_rounds = training_cfg.get("boosting_rounds", 1500)
143
  early_stopping = training_cfg.get("early_stopping_rounds", 100)
144
  seed = config.get("seed", 42)
145
 
 
146
  n_train = len(train_df)
147
- n_tasks = len(target_names)
148
 
149
  train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
150
  val_preds = (
@@ -156,72 +177,74 @@ def train_stage_one_models(
156
  metrics: Dict[str, Dict] = {}
157
  params_dump: Dict[str, Dict] = {}
158
 
159
- for task_idx, task_name in enumerate(target_names):
160
- train_mask = train_df[task_name].notna().values
161
- if val_df is None or val_features is None:
162
- metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
163
- continue
164
-
165
- val_mask = val_df[task_name].notna().values
166
- if train_mask.sum() < 2 or val_mask.sum() < 2:
167
- metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
168
- continue
169
-
170
- X_train_task = train_features[train_mask]
171
- y_train_task = train_df.loc[train_mask, task_name].astype(float).values
172
- X_val_task = val_features[val_mask]
173
- y_val_task = val_df.loc[val_mask, task_name].astype(float).values
174
-
175
- if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
176
- metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
177
- continue
178
-
179
- task_result = train_lightgbm_task(
180
- X_train_task,
181
- y_train_task,
182
- X_val_task,
183
- y_val_task,
184
- base_params=base_params,
185
- boosting_rounds=boosting_rounds,
186
- early_stopping_rounds=early_stopping,
187
- n_trials=n_trials,
188
- seed=seed,
189
- )
190
-
191
- if task_result is None:
192
- metrics[task_name] = {"status": "skipped", "reason": "training failed"}
193
- continue
194
 
195
- model = task_result.model
196
- best_iter = task_result.best_iteration
 
197
 
198
- model_path = stage_dir / f"{task_name}.pkl"
199
- joblib.dump(model, model_path)
200
 
201
- params_dump[task_name] = {
202
- **task_result.best_params,
203
- "best_iteration": best_iter,
204
- "val_auc": task_result.val_auc,
205
- }
206
 
207
- full_train_preds = model.predict_proba(
208
- train_features,
209
- num_iteration=best_iter,
210
- )[:, 1]
211
- train_preds[:, task_idx] = full_train_preds.astype(np.float32)
212
 
213
- if val_preds is not None:
214
- full_val_preds = model.predict_proba(
215
- val_features,
216
  num_iteration=best_iter,
217
  )[:, 1]
218
- val_preds[:, task_idx] = full_val_preds.astype(np.float32)
219
-
220
- metrics[task_name] = {
221
- "val_auc": task_result.val_auc,
222
- "n_train_samples": int(train_mask.sum()),
223
- "n_val_samples": int(val_mask.sum()),
224
- }
 
 
 
 
 
 
 
225
 
226
  save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
227
  params_path = checkpoint_dir / "stage1_params.json"
 
11
  import optuna
12
  import pandas as pd
13
  from sklearn.metrics import roc_auc_score
14
+ from tqdm import tqdm
15
 
16
  from .constants import TARGET_NAMES
17
 
 
24
  best_params: Dict
25
 
26
 
27
+ def resolve_n_estimators(training_cfg: Dict) -> Sequence[int]:
28
+ """Normalize the n_estimators config entry into a non-empty list of ints."""
29
+ if "n_estimators" in training_cfg:
30
+ raw_value = training_cfg["n_estimators"]
31
+ elif "boosting_rounds" in training_cfg:
32
+ raw_value = training_cfg["boosting_rounds"]
33
+ else:
34
+ raw_value = [50, 500, 1000]
35
+
36
+ if isinstance(raw_value, int):
37
+ choices = [int(raw_value)]
38
+ elif isinstance(raw_value, Sequence) and not isinstance(raw_value, (str, bytes)):
39
+ choices = [int(v) for v in raw_value]
40
+ else:
41
+ raise ValueError("training.n_estimators must be an int or a sequence of ints")
42
+
43
+ choices = [v for v in choices if v > 0]
44
+ if not choices:
45
+ raise ValueError("training.n_estimators must contain at least one positive value")
46
+ return choices
47
+
48
+
49
+ def _sample_hyperparams(trial: optuna.Trial, base_params: Dict, n_estimators_choices: Sequence[int]) -> Dict:
50
  params = dict(base_params)
51
  params.update(
52
  {
 
59
  "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
60
  "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
61
  "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
62
+ "n_estimators": trial.suggest_categorical("n_estimators", list(n_estimators_choices)),
63
  }
64
  )
65
  params.setdefault("objective", "binary")
 
76
  X_val: np.ndarray,
77
  y_val: np.ndarray,
78
  base_params: Dict,
79
+ n_estimators_choices: Sequence[int],
80
  early_stopping_rounds: int,
81
  n_trials: int,
82
  seed: int,
 
85
  return None
86
 
87
  def objective(trial: optuna.Trial) -> float:
88
+ params = _sample_hyperparams(trial, base_params, n_estimators_choices)
 
89
  params["random_state"] = seed
90
  model = lgb.LGBMClassifier(**params)
91
  model.fit(
 
100
  verbose=False,
101
  )
102
  ],
 
103
  )
104
+ best_iter = getattr(model, "best_iteration_", params["n_estimators"])
105
  preds = model.predict_proba(X_val, num_iteration=best_iter)[:, 1]
106
  return float(roc_auc_score(y_val, preds))
107
 
108
  study = optuna.create_study(direction="maximize")
109
  study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
110
 
111
+    best_params = _sample_hyperparams(study.best_trial, base_params, n_estimators_choices)
     best_params["random_state"] = seed

     final_model = lgb.LGBMClassifier(**best_params)

                 verbose=False,
             )
         ],
     )

+    best_iteration = getattr(final_model, "best_iteration_", best_params["n_estimators"])
     val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
     val_auc = roc_auc_score(y_val, val_preds)

     training_cfg = config.get("training", {})
     base_params = training_cfg.get("lightgbm_params", {})
     n_trials = training_cfg.get("optuna_trials", 40)
+    n_estimators_choices = resolve_n_estimators(training_cfg)
     early_stopping = training_cfg.get("early_stopping_rounds", 100)
     seed = config.get("seed", 42)

+    task_list = list(target_names)
     n_train = len(train_df)
+    n_tasks = len(task_list)

     train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
     val_preds = (

     metrics: Dict[str, Dict] = {}
     params_dump: Dict[str, Dict] = {}

+    with tqdm(task_list, desc="Stage 1", unit="task") as progress_bar:
+        for task_idx, task_name in enumerate(progress_bar):
+            progress_bar.set_postfix(task=task_name)
+            train_mask = train_df[task_name].notna().values
+            if val_df is None or val_features is None:
+                metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
+                continue
+
+            val_mask = val_df[task_name].notna().values
+            if train_mask.sum() < 2 or val_mask.sum() < 2:
+                metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
+                continue
+
+            X_train_task = train_features[train_mask]
+            y_train_task = train_df.loc[train_mask, task_name].astype(float).values
+            X_val_task = val_features[val_mask]
+            y_val_task = val_df.loc[val_mask, task_name].astype(float).values
+
+            if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
+                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
+                continue
+
+            task_result = train_lightgbm_task(
+                X_train_task,
+                y_train_task,
+                X_val_task,
+                y_val_task,
+                base_params=base_params,
+                n_estimators_choices=n_estimators_choices,
+                early_stopping_rounds=early_stopping,
+                n_trials=n_trials,
+                seed=seed,
+            )
+
+            if task_result is None:
+                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
+                continue
+
+            model = task_result.model
+            best_iter = task_result.best_iteration
+
+            model_path = stage_dir / f"{task_name}.pkl"
+            joblib.dump(model, model_path)
+
+            params_dump[task_name] = {
+                **task_result.best_params,
+                "best_iteration": best_iter,
+                "val_auc": task_result.val_auc,
+            }
+
+            full_train_preds = model.predict_proba(
+                train_features,
                 num_iteration=best_iter,
             )[:, 1]
+            train_preds[:, task_idx] = full_train_preds.astype(np.float32)
+
+            if val_preds is not None:
+                full_val_preds = model.predict_proba(
+                    val_features,
+                    num_iteration=best_iter,
+                )[:, 1]
+                val_preds[:, task_idx] = full_val_preds.astype(np.float32)
+
+            metrics[task_name] = {
+                "val_auc": task_result.val_auc,
+                "n_train_samples": int(train_mask.sum()),
+                "n_val_samples": int(val_mask.sum()),
+            }

     save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
     params_path = checkpoint_dir / "stage1_params.json"
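The diff introduces a `resolve_n_estimators` helper but its body is not part of this commit. A minimal sketch of what such a helper might do, assuming it maps the training config to a list of `n_estimators` candidates for Optuna's categorical sampling (the key names and the 1500 fallback are guesses taken from the old `boosting_rounds` default, not the repository's actual code):

```python
from typing import Dict, List


def resolve_n_estimators(training_cfg: Dict) -> List[int]:
    """Return candidate n_estimators values for the hyperparameter search.

    Falls back to the legacy ``boosting_rounds`` key, then to 1500, so old
    configs keep working; a single int becomes a one-element choice list.
    """
    value = training_cfg.get("n_estimators", training_cfg.get("boosting_rounds", 1500))
    if isinstance(value, int):
        return [value]
    return [int(v) for v in value]


choices = resolve_n_estimators({"boosting_rounds": 800})
```

With a list-valued config entry, every trial can then pick its own boosting budget via `trial.suggest_categorical`, which is why the fixed `boosting_rounds` argument disappears from `train_lightgbm_task`.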
src/stage_two.py CHANGED
@@ -6,10 +6,10 @@ from typing import Dict, Optional, Sequence
 import joblib
 import numpy as np
 import pandas as pd
-from sklearn.metrics import roc_auc_score
+from tqdm import tqdm

 from .constants import TARGET_NAMES
-from .lightgbm_trainer import save_stage_metrics, train_lightgbm_task
+from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task


 def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
@@ -32,75 +32,75 @@ def train_stage_two_models(
     training_cfg = config.get("training", {})
     base_params = training_cfg.get("lightgbm_params", {})
     n_trials = training_cfg.get("optuna_trials", 40)
-    boosting_rounds = training_cfg.get("boosting_rounds", 1500)
+    n_estimators_choices = resolve_n_estimators(training_cfg)
     early_stopping = training_cfg.get("early_stopping_rounds", 100)
     seed = config.get("seed", 42)

     stage_dir = checkpoint_dir / "stage2"
     stage_dir.mkdir(parents=True, exist_ok=True)

-    n_train = len(train_df)
-    n_val = len(val_df) if val_df is not None else 0
-
     metrics: Dict[str, Dict] = {}
-
-    for task_idx, task_name in enumerate(target_names):
-        mask = train_df[task_name].notna().values
-        if mask.sum() == 0:
-            metrics[task_name] = {"status": "skipped", "reason": "no labels"}
-            continue
-
-        augmented_train_matrix = _build_augmented_matrix(
-            train_features[mask],
-            stage1_train_preds[mask],
-            task_idx,
-        )
-        y_train = train_df.loc[mask, task_name].astype(float).values
-
-        if (
-            val_features is None
-            or val_df is None
-            or stage1_val_preds is None
-            or val_df[task_name].notna().sum() < 2
-        ):
-            metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
-            continue
-
-        val_mask = val_df[task_name].notna().values
-        augmented_val_matrix = _build_augmented_matrix(
-            val_features[val_mask],
-            stage1_val_preds[val_mask],
-            task_idx,
-        )
-        y_val = val_df.loc[val_mask, task_name].astype(float).values
-
-        if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
-            metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
-            continue
-
-        task_result = train_lightgbm_task(
-            augmented_train_matrix,
-            y_train,
-            augmented_val_matrix,
-            y_val,
-            base_params=base_params,
-            boosting_rounds=boosting_rounds,
-            early_stopping_rounds=early_stopping,
-            n_trials=n_trials,
-            seed=seed,
-        )
-
-        if task_result is None:
-            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
-            continue
-
-        model_path = stage_dir / f"{task_name}.pkl"
-        joblib.dump(task_result.model, model_path)
-
-        metrics[task_name] = {
-            "val_auc": task_result.val_auc,
-            "best_iteration": int(task_result.best_iteration),
-        }
+    task_list = list(target_names)
+
+    with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar:
+        for task_idx, task_name in enumerate(progress_bar):
+            progress_bar.set_postfix(task=task_name)
+            mask = train_df[task_name].notna().values
+            if mask.sum() == 0:
+                metrics[task_name] = {"status": "skipped", "reason": "no labels"}
+                continue
+
+            augmented_train_matrix = _build_augmented_matrix(
+                train_features[mask],
+                stage1_train_preds[mask],
+                task_idx,
+            )
+            y_train = train_df.loc[mask, task_name].astype(float).values
+
+            if (
+                val_features is None
+                or val_df is None
+                or stage1_val_preds is None
+                or val_df[task_name].notna().sum() < 2
+            ):
+                metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
+                continue
+
+            val_mask = val_df[task_name].notna().values
+            augmented_val_matrix = _build_augmented_matrix(
+                val_features[val_mask],
+                stage1_val_preds[val_mask],
+                task_idx,
+            )
+            y_val = val_df.loc[val_mask, task_name].astype(float).values
+
+            if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
+                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
+                continue
+
+            task_result = train_lightgbm_task(
+                augmented_train_matrix,
+                y_train,
+                augmented_val_matrix,
+                y_val,
+                base_params=base_params,
+                n_estimators_choices=n_estimators_choices,
+                early_stopping_rounds=early_stopping,
+                n_trials=n_trials,
+                seed=seed,
+            )
+
+            if task_result is None:
+                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
+                continue
+
+            model_path = stage_dir / f"{task_name}.pkl"
+            joblib.dump(task_result.model, model_path)
+
+            metrics[task_name] = {
+                "val_auc": task_result.val_auc,
+                "best_iteration": int(task_result.best_iteration),
+            }

     save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
     return {"metrics": metrics}
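The body of `_build_augmented_matrix` is outside this diff; only its signature appears. Based on the README's description of stage two (the fingerprint vector is augmented with the predictions of the *other* tasks), a plausible standalone sketch looks like this; the function name, the `float32` cast, and the column-dropping detail are assumptions, not the repository's actual implementation:

```python
import numpy as np


def build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
    """Concatenate base features with the stage-1 predictions of the other tasks.

    The target task's own prediction column is dropped so the stage-2 model
    never sees a stage-1 estimate of the very label it is trying to learn.
    """
    other_task_preds = np.delete(prediction_matrix, target_idx, axis=1)
    return np.hstack([base_features, other_task_preds]).astype(np.float32)


# With 12 Tox21 tasks, each stage-2 model would see the fingerprint plus 11 extra columns.
features = np.zeros((5, 2048), dtype=np.float32)       # e.g. 2048-bit ECFP4 fingerprints
stage1_preds = np.full((5, 12), 0.5, dtype=np.float32)  # neutral stage-1 probabilities
augmented = build_augmented_matrix(features, stage1_preds, target_idx=3)
```

This also explains why the stage-2 loop indexes `stage1_train_preds[mask]` with the same row mask as the features: the two matrices must stay row-aligned before concatenation.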