Commit 759324e
Maximilian Schuh committed
Parent(s): e1aa0ed

Added new files and updated learning
- README.md +2 -33
- checkpoints/cache/stage1_train_predictions.npz +3 -0
- checkpoints/cache/stage1_validation_predictions.npz +3 -0
- checkpoints/cache/train_ecfp.npz +3 -0
- checkpoints/cache/validation_ecfp.npz +3 -0
- checkpoints/metrics_stage1.json +62 -0
- checkpoints/metrics_stage2.json +50 -0
- checkpoints/stage1/NR-AR-LBD.pkl +3 -0
- checkpoints/stage1/NR-AR.pkl +3 -0
- checkpoints/stage1/NR-AhR.pkl +3 -0
- checkpoints/stage1/NR-Aromatase.pkl +3 -0
- checkpoints/stage1/NR-ER-LBD.pkl +3 -0
- checkpoints/stage1/NR-ER.pkl +3 -0
- checkpoints/stage1/NR-PPAR-gamma.pkl +3 -0
- checkpoints/stage1/SR-ARE.pkl +3 -0
- checkpoints/stage1/SR-ATAD5.pkl +3 -0
- checkpoints/stage1/SR-HSE.pkl +3 -0
- checkpoints/stage1/SR-MMP.pkl +3 -0
- checkpoints/stage1/SR-p53.pkl +3 -0
- checkpoints/stage1_params.json +242 -0
- checkpoints/stage2/NR-AR-LBD.pkl +3 -0
- checkpoints/stage2/NR-AR.pkl +3 -0
- checkpoints/stage2/NR-AhR.pkl +3 -0
- checkpoints/stage2/NR-Aromatase.pkl +3 -0
- checkpoints/stage2/NR-ER-LBD.pkl +3 -0
- checkpoints/stage2/NR-ER.pkl +3 -0
- checkpoints/stage2/NR-PPAR-gamma.pkl +3 -0
- checkpoints/stage2/SR-ARE.pkl +3 -0
- checkpoints/stage2/SR-ATAD5.pkl +3 -0
- checkpoints/stage2/SR-HSE.pkl +3 -0
- checkpoints/stage2/SR-MMP.pkl +3 -0
- checkpoints/stage2/SR-p53.pkl +3 -0
- checkpoints/training_manifest.json +41 -0
- config/config.json +2 -2
- requirements.txt +1 -0
- src/lightgbm_trainer.py +94 -71
- src/stage_two.py +64 -64
README.md CHANGED

@@ -20,30 +20,6 @@ MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [To
 - **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
 - **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
 
-## Installation
-
-```bash
-git clone https://huggingface.co/spaces/ml-jku/tox21_gin_classifier
-cd tox21_gin_classifier
-python -m venv .venv && source .venv/bin/activate
-pip install --upgrade pip
-pip install -r requirements.txt
-```
-
-The requirements include RDKit, LightGBM, Optuna, and the MAP4 fingerprint package so you can switch feature types via the config.
-
-## Training
-
-1. Create a `.env` file (all Hugging Face Spaces support secrets) with your dataset token:
-```
-TOKEN=hf_xxx
-```
-2. Adjust `config/config.json` if needed (fingerprint type, Optuna trial count, etc.).
-3. Run:
-```bash
-python train.py
-```
-
 ### What `train.py` does
 
 1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
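The README's multitask enhancement is the heart of this design: for each task, stage two concatenates the molecule's fingerprint with the stage-one predictions of the *other* eleven tasks. A minimal sketch of that feature-augmentation step, using plain Python lists and a hypothetical helper name (not a function from this repo):

```python
def augment_features(fingerprint, stage1_preds, task_index):
    """Append stage-one predictions of all *other* tasks to the fingerprint,
    mirroring the multitask enhancement described in the README.
    Hypothetical helper, not taken from the repository's source."""
    other_preds = [p for i, p in enumerate(stage1_preds) if i != task_index]
    return list(fingerprint) + other_preds

# Toy 4-bit fingerprint and 3 tasks; building stage-two input for task 1:
# the fingerprint is extended with the predictions for tasks 0 and 2 only.
features = augment_features([1, 0, 0, 1], [0.9, 0.2, 0.4], task_index=1)
print(features)  # [1, 0, 0, 1, 0.9, 0.4]
```

In the real pipeline the fingerprint would be a 1024-bit ECFP array and the predictions would come from the cached `stage1_*_predictions.npz` files added in this commit.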
@@ -88,7 +64,7 @@ The function:
     },
     "training": {
       "optuna_trials": 40,
-      "
       "early_stopping_rounds": 100,
       "lightgbm_params": {
         "objective": "binary",

@@ -104,6 +80,7 @@ The function:
 - Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
 - Disable multitask behavior by setting `"multitask": {"enabled": false}`.
 - Increase `optuna_trials` for a more exhaustive search if compute allows.
 
 ## Repository Layout
 
@@ -116,11 +93,3 @@ The function:
 - `src/constants.py`, `src/seed.py` – shared utilities.
 - `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
 - `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
-
-## Tips
-
-- Training relies on the `TOKEN` environment variable to access the Tox21 dataset on Hugging Face. Locally you can omit it if the dataset is public for your account.
-- MAP4 fingerprints are more expensive to compute; enable the cache directory to avoid recomputation across runs.
-- Use the saved metrics files to compare stage-one vs. stage-two AUCs and to trace which configuration produced a set of checkpoints.
-
-Happy modeling! If you extend MultiTaskTox (new fingerprints, alternative learners, etc.), keep the `predict(smiles)` contract intact so your Space remains leaderboard compatible.
New side of the same hunks:

     },
     "training": {
       "optuna_trials": 40,
+      "n_estimators": [50, 500, 1000],
       "early_stopping_rounds": 100,
       "lightgbm_params": {
         "objective": "binary",

 - Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
 - Disable multitask behavior by setting `"multitask": {"enabled": false}`.
 - Increase `optuna_trials` for a more exhaustive search if compute allows.
+- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
 
 ## Repository Layout
 
 - `src/constants.py`, `src/seed.py` – shared utilities.
 - `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
 - `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
checkpoints/cache/stage1_train_predictions.npz ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:019864a99cd005c40bc979da6ec6594398a02fdeaaedb943ea336b7603efab49
+size 565279
checkpoints/cache/stage1_validation_predictions.npz ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eab5701f117dbe4f1bd0755d870e1f5518e9fb99eae505c459af3576534fb843
+size 15007

checkpoints/cache/train_ecfp.npz ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc0c773eb523d256b84fec1cc314f290bcc5b9288073aff576712437450bf5a
+size 48726858

checkpoints/cache/validation_ecfp.npz ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52f576b159c6e09d1ca22f0e260256cd46deacb137dfe77f127c756117671830
+size 1224294
checkpoints/metrics_stage1.json ADDED

@@ -0,0 +1,62 @@
+{
+  "NR-AhR": {
+    "val_auc": 0.8789764868603043,
+    "n_train_samples": 8165,
+    "n_val_samples": 271
+  },
+  "NR-AR": {
+    "val_auc": 0.8547453703703702,
+    "n_train_samples": 9358,
+    "n_val_samples": 291
+  },
+  "NR-AR-LBD": {
+    "val_auc": 0.9717741935483871,
+    "n_train_samples": 8595,
+    "n_val_samples": 252
+  },
+  "NR-Aromatase": {
+    "val_auc": 0.8483560090702948,
+    "n_train_samples": 7222,
+    "n_val_samples": 214
+  },
+  "NR-ER": {
+    "val_auc": 0.7974683544303797,
+    "n_train_samples": 7694,
+    "n_val_samples": 264
+  },
+  "NR-ER-LBD": {
+    "val_auc": 0.8347826086956521,
+    "n_train_samples": 8749,
+    "n_val_samples": 286
+  },
+  "NR-PPAR-gamma": {
+    "val_auc": 0.8077025232403718,
+    "n_train_samples": 8180,
+    "n_val_samples": 266
+  },
+  "SR-ARE": {
+    "val_auc": 0.8194921070693205,
+    "n_train_samples": 7165,
+    "n_val_samples": 233
+  },
+  "SR-ATAD5": {
+    "val_auc": 0.8749593495934959,
+    "n_train_samples": 9087,
+    "n_val_samples": 271
+  },
+  "SR-HSE": {
+    "val_auc": 0.92421875,
+    "n_train_samples": 8147,
+    "n_val_samples": 266
+  },
+  "SR-MMP": {
+    "val_auc": 0.9280613594287226,
+    "n_train_samples": 7317,
+    "n_val_samples": 237
+  },
+  "SR-p53": {
+    "val_auc": 0.8038690476190476,
+    "n_train_samples": 8630,
+    "n_val_samples": 268
+  }
+}
checkpoints/metrics_stage2.json ADDED

@@ -0,0 +1,50 @@
+{
+  "NR-AhR": {
+    "val_auc": 0.8740663900414938,
+    "best_iteration": 5
+  },
+  "NR-AR": {
+    "val_auc": 0.8854166666666666,
+    "best_iteration": 204
+  },
+  "NR-AR-LBD": {
+    "val_auc": 0.9828629032258065,
+    "best_iteration": 23
+  },
+  "NR-Aromatase": {
+    "val_auc": 0.8572845804988662,
+    "best_iteration": 8
+  },
+  "NR-ER": {
+    "val_auc": 0.8123144241287701,
+    "best_iteration": 2
+  },
+  "NR-ER-LBD": {
+    "val_auc": 0.857608695652174,
+    "best_iteration": 4
+  },
+  "NR-PPAR-gamma": {
+    "val_auc": 0.8739707835325365,
+    "best_iteration": 13
+  },
+  "SR-ARE": {
+    "val_auc": 0.857698467169984,
+    "best_iteration": 4
+  },
+  "SR-ATAD5": {
+    "val_auc": 0.8530081300813008,
+    "best_iteration": 145
+  },
+  "SR-HSE": {
+    "val_auc": 0.9201171875,
+    "best_iteration": 2
+  },
+  "SR-MMP": {
+    "val_auc": 0.9312351229833377,
+    "best_iteration": 395
+  },
+  "SR-p53": {
+    "val_auc": 0.8722470238095239,
+    "best_iteration": 3
+  }
+}
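The two metrics files make it easy to check whether the multitask second stage actually helped each task. A sketch of that comparison, using excerpts copied from the JSON above (in practice you would `json.load` both files from `checkpoints/`):

```python
# Excerpts from checkpoints/metrics_stage1.json and metrics_stage2.json above
stage1 = {"NR-AR": {"val_auc": 0.8547453703703702},
          "SR-p53": {"val_auc": 0.8038690476190476}}
stage2 = {"NR-AR": {"val_auc": 0.8854166666666666},
          "SR-p53": {"val_auc": 0.8722470238095239}}

# Positive delta = stage two improved on stage one for that task
deltas = {t: stage2[t]["val_auc"] - stage1[t]["val_auc"] for t in stage1}
for task, delta in sorted(deltas.items()):
    print(f"{task}: {delta:+.4f}")
```

On these two tasks the stage-two models improve validation AUC by roughly 0.03 and 0.07 respectively; for a few tasks in the full files (e.g. NR-AhR, SR-HSE) the delta is slightly negative.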
checkpoints/stage1/NR-AR-LBD.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d1b4030d07f948b2b549fc6426587a05c269648fbc40131d6e99241c8972d28
+size 627492

checkpoints/stage1/NR-AR.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:901eb451790a2607cc3e3c6eb6328f947899b557ec9887adc2481aea5a95433e
+size 24772

checkpoints/stage1/NR-AhR.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d2afa34f2143ab1d38247ae82d9c859ea74c37799134c571e30a3ec3e0cfe10
+size 1001188

checkpoints/stage1/NR-Aromatase.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf3aef5acc8411f33bbfe184a45ecb8b292eee93260f1b688a82e2ba26e83ae6
+size 428964

checkpoints/stage1/NR-ER-LBD.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d42aa1195696a8bab89cf0a6ad5914cbfb198d8e946d316dbec3df22621181b
+size 44628

checkpoints/stage1/NR-ER.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8be3a13fc68d2c96b827dab8771026de3944e400150b3b173cc77aa23eec328e
+size 113444

checkpoints/stage1/NR-PPAR-gamma.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6cbc9dce92de09a74df3df581de742044f1ae7d522a5018d8006474d089768d
+size 58420

checkpoints/stage1/SR-ARE.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47fc92884f6ed03d8738f58b809a6c616d13dc7322249008941fbd2506f834bf
+size 392756

checkpoints/stage1/SR-ATAD5.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d00260b6422786faed1efafb99deca01be4021bc7bf4b6da8e807beacac6e14a
+size 375092

checkpoints/stage1/SR-HSE.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7b67c8aba900b7036a4d41e1171d3dabee2423fb1a08a26accf5b3d533065da
+size 237860

checkpoints/stage1/SR-MMP.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:528a7efd2fe9748037da657d88959821f994dfbe35cdbec0a4c415bd3a9ed898
+size 606164

checkpoints/stage1/SR-p53.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c2a2dceb70d39f5d72cf68f104be68c763dcf25da33b41bd45abcbcc2e07906
+size 1101540
checkpoints/stage1_params.json ADDED

@@ -0,0 +1,242 @@
+{
+  "NR-AhR": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.04596965586436891,
+    "num_leaves": 26,
+    "max_depth": 11,
+    "min_child_samples": 14,
+    "feature_fraction": 0.9910514912963474,
+    "bagging_fraction": 0.5143055600954589,
+    "bagging_freq": 10,
+    "reg_alpha": 1.514871105815587e-07,
+    "reg_lambda": 1.0837426184192575e-05,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 321,
+    "val_auc": 0.8789764868603043
+  },
+  "NR-AR": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.17083936308641903,
+    "num_leaves": 80,
+    "max_depth": 3,
+    "min_child_samples": 91,
+    "feature_fraction": 0.9278238464266182,
+    "bagging_fraction": 0.6395695127332895,
+    "bagging_freq": 7,
+    "reg_alpha": 5.1826386659948485,
+    "reg_lambda": 0.1806421496400551,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 1,
+    "val_auc": 0.8547453703703702
+  },
+  "NR-AR-LBD": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1992149524787967,
+    "num_leaves": 180,
+    "max_depth": 12,
+    "min_child_samples": 49,
+    "feature_fraction": 0.5650870342995183,
+    "bagging_fraction": 0.5114405817018007,
+    "bagging_freq": 2,
+    "reg_alpha": 0.00012622761264931666,
+    "reg_lambda": 1.2770493065015112e-08,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 210,
+    "val_auc": 0.9717741935483871
+  },
+  "NR-Aromatase": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.07977276592360394,
+    "num_leaves": 225,
+    "max_depth": 10,
+    "min_child_samples": 10,
+    "feature_fraction": 0.9900155440745377,
+    "bagging_fraction": 0.8119652252471269,
+    "bagging_freq": 2,
+    "reg_alpha": 2.594339538659424e-06,
+    "reg_lambda": 5.04193406028159e-07,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 62,
+    "val_auc": 0.8483560090702948
+  },
+  "NR-ER": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.011154410439266932,
+    "num_leaves": 193,
+    "max_depth": 0,
+    "min_child_samples": 13,
+    "feature_fraction": 0.976937312627059,
+    "bagging_fraction": 0.8101945461275061,
+    "bagging_freq": 10,
+    "reg_alpha": 3.712452767192828e-05,
+    "reg_lambda": 2.4101531861104065e-07,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 4,
+    "val_auc": 0.7974683544303797
+  },
+  "NR-ER-LBD": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.002057546388429118,
+    "num_leaves": 46,
+    "max_depth": 7,
+    "min_child_samples": 183,
+    "feature_fraction": 0.9180640687080166,
+    "bagging_fraction": 0.5266816009973788,
+    "bagging_freq": 2,
+    "reg_alpha": 0.4206611876821779,
+    "reg_lambda": 0.024870165101451694,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 18,
+    "val_auc": 0.8347826086956521
+  },
+  "NR-PPAR-gamma": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.05900664732601199,
+    "num_leaves": 235,
+    "max_depth": 10,
+    "min_child_samples": 73,
+    "feature_fraction": 0.9286215997058814,
+    "bagging_fraction": 0.6099533117515117,
+    "bagging_freq": 8,
+    "reg_alpha": 0.0002410679912248092,
+    "reg_lambda": 2.2374585644452814e-05,
+    "n_estimators": 50,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 16,
+    "val_auc": 0.8077025232403718
+  },
+  "SR-ARE": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.11683900433079997,
+    "num_leaves": 25,
+    "max_depth": 6,
+    "min_child_samples": 121,
+    "feature_fraction": 0.9155819290034379,
+    "bagging_fraction": 0.6064835300540737,
+    "bagging_freq": 6,
+    "reg_alpha": 0.031558826962596216,
+    "reg_lambda": 0.49750603290125384,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 313,
+    "val_auc": 0.8194921070693205
+  },
+  "SR-ATAD5": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1398668165886807,
+    "num_leaves": 71,
+    "max_depth": 12,
+    "min_child_samples": 96,
+    "feature_fraction": 0.86899957688569,
+    "bagging_fraction": 0.7807007020967096,
+    "bagging_freq": 10,
+    "reg_alpha": 1.010171177179027e-08,
+    "reg_lambda": 2.3747514557444565e-07,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 126,
+    "val_auc": 0.8749593495934959
+  },
+  "SR-HSE": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.15091207817136804,
+    "num_leaves": 246,
+    "max_depth": 0,
+    "min_child_samples": 19,
+    "feature_fraction": 0.7867053613239711,
+    "bagging_fraction": 0.7013484568271124,
+    "bagging_freq": 9,
+    "reg_alpha": 0.0006360962863973946,
+    "reg_lambda": 6.440534124809522e-05,
+    "n_estimators": 500,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 14,
+    "val_auc": 0.92421875
+  },
+  "SR-MMP": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.19884775296113416,
+    "num_leaves": 18,
+    "max_depth": 11,
+    "min_child_samples": 88,
+    "feature_fraction": 0.9550517881195121,
+    "bagging_fraction": 0.7860517484953123,
+    "bagging_freq": 1,
+    "reg_alpha": 0.030296963787402084,
+    "reg_lambda": 0.8044239737357854,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 284,
+    "val_auc": 0.9280613594287226
+  },
+  "SR-p53": {
+    "objective": "binary",
+    "metric": "auc",
+    "verbosity": -1,
+    "learning_rate": 0.1655358096626077,
+    "num_leaves": 85,
+    "max_depth": 0,
+    "min_child_samples": 87,
+    "feature_fraction": 0.8995086723685061,
+    "bagging_fraction": 0.9621288945710826,
+    "bagging_freq": 6,
+    "reg_alpha": 0.001186616750567751,
+    "reg_lambda": 0.00030749373152708483,
+    "n_estimators": 1000,
+    "boosting_type": "gbdt",
+    "n_jobs": -1,
+    "random_state": 42,
+    "best_iteration": 132,
+    "val_auc": 0.8038690476190476
+  }
+}
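Each per-task entry in `stage1_params.json` mixes LightGBM constructor arguments with two bookkeeping fields (`best_iteration`, `val_auc`) that are *not* constructor arguments and would have to be stripped before re-instantiating a model with `lgb.LGBMClassifier(**params)`. A dependency-free sketch of that split (the helper name is mine, not from the repo):

```python
BOOKKEEPING = {"best_iteration", "val_auc"}

def constructor_params(task_entry):
    """Drop the bookkeeping keys, leaving only the LightGBM keyword
    arguments that can be passed to the classifier constructor."""
    return {k: v for k, v in task_entry.items() if k not in BOOKKEEPING}

# Abbreviated NR-AhR entry from the JSON above
entry = {"objective": "binary", "n_estimators": 500, "random_state": 42,
         "best_iteration": 321, "val_auc": 0.8789764868603043}
params = constructor_params(entry)
print(sorted(params))  # ['n_estimators', 'objective', 'random_state']
```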
checkpoints/stage2/NR-AR-LBD.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b877924ff923463aeb3d9347cc2d64077de0fd1e86c5ebceb1d6e85e7748cbba
+size 71988

checkpoints/stage2/NR-AR.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02e93424e24eb7a4fb704a191ad6d68c05145ab333d434c089e5bb98ff94fb9c
+size 1836148

checkpoints/stage2/NR-AhR.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6681c8042c04542aed22e3c7e1bc756d78cbd90035a1717050202e4841a32c40
+size 61716

checkpoints/stage2/NR-Aromatase.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58d2b37062490cb61caf6b36ed765493fb17bd1805a2f9242e3c8dda058ed05f
+size 41540

checkpoints/stage2/NR-ER-LBD.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be0e41ee6f1e0463073f4eeb983bfb6f7bcedfba2174e18bab55118feec9ef0c
+size 37092

checkpoints/stage2/NR-ER.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aaaec49348aef97d610e5b0dd01ea41728474a5c8fc5cee3e961f0b97a28a1f
+size 28884

checkpoints/stage2/NR-PPAR-gamma.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f95337de1dd3dd937e3e15c74ff80b9cdb4d33b445446c1cd1e4044bcc3cb21
+size 74644

checkpoints/stage2/SR-ARE.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9041bf8b713c4ec2793b67a41d2e302b9aa1f615625a2be53c12efdfdec1232c
+size 52196

checkpoints/stage2/SR-ATAD5.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aafc9da7d6dc67e322264cfdd34788b2aa509e91c9de58a3369ce366b89fbec6
+size 485092

checkpoints/stage2/SR-HSE.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:763d559e337eb24b837d7a0cb9a530baa6b3ac8eccd1195e68ba39fbb078d3cb
+size 30452

checkpoints/stage2/SR-MMP.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8565440bb1a9284032924de804d10993efefd71ce00f92f4014cc11509f51a2
+size 344964

checkpoints/stage2/SR-p53.pkl ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f446318ecd2ac93f685664bf68a0866e23a9275d53a32199f6f74557577c2152
+size 30916
checkpoints/training_manifest.json ADDED

@@ -0,0 +1,41 @@
+{
+  "feature_config": {
+    "type": "ecfp",
+    "radius": 2,
+    "n_bits": 1024,
+    "use_counts": false,
+    "map4_dim": 1024,
+    "cache_dir": "./checkpoints/cache"
+  },
+  "target_names": [
+    "NR-AhR",
+    "NR-AR",
+    "NR-AR-LBD",
+    "NR-Aromatase",
+    "NR-ER",
+    "NR-ER-LBD",
+    "NR-PPAR-gamma",
+    "SR-ARE",
+    "SR-ATAD5",
+    "SR-HSE",
+    "SR-MMP",
+    "SR-p53"
+  ],
+  "dataset": {
+    "name": "ml-jku/tox21"
+  },
+  "stage1": {
+    "model_dir": "checkpoints/stage1",
+    "metrics": "checkpoints/metrics_stage1.json"
+  },
+  "stage2": {
+    "enabled": true,
+    "model_dir": "checkpoints/stage2",
+    "metrics": "checkpoints/metrics_stage2.json"
+  },
+  "multitask": {
+    "enabled": true,
+    "prediction_source": "oof"
+  },
+  "seed": 42
+}
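The README calls this "the training manifest used at inference time": `predict.py` presumably reads it to learn which fingerprint to compute and which per-task checkpoints to load. A sketch of resolving checkpoint paths from a manifest-shaped dict (the keys come from the JSON above; the resolution logic itself is an assumption, not code from the repo):

```python
# Abbreviated manifest mirroring checkpoints/training_manifest.json above
manifest = {
    "target_names": ["NR-AhR", "NR-AR"],
    "stage1": {"model_dir": "checkpoints/stage1"},
    "stage2": {"enabled": True, "model_dir": "checkpoints/stage2"},
}

def checkpoint_paths(manifest):
    """Pick stage-two checkpoints when enabled, else fall back to stage one."""
    stage = manifest["stage2"] if manifest["stage2"]["enabled"] else manifest["stage1"]
    return {t: f"{stage['model_dir']}/{t}.pkl" for t in manifest["target_names"]}

print(checkpoint_paths(manifest)["NR-AhR"])  # checkpoints/stage2/NR-AhR.pkl
```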
config/config.json CHANGED

@@ -12,8 +12,8 @@
     "cache_dir": "./checkpoints/cache"
   },
   "training": {
-    "optuna_trials": 
-    "
+    "optuna_trials": 1000,
+    "n_estimators": [50, 500, 1000],
     "early_stopping_rounds": 100,
     "lightgbm_params": {
       "objective": "binary",
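Per the README note in this commit, the new `n_estimators` key accepts either a single integer or a list of candidates for the Optuna search. Normalizing that into a candidate list might look like this (a sketch; the helper name and default are assumptions based on the documented default `[50, 500, 1000]`):

```python
def n_estimators_candidates(training_cfg):
    """Return a list of n_estimators candidates whether the config
    holds a single int or a list of values."""
    value = training_cfg.get("n_estimators", [50, 500, 1000])
    return [value] if isinstance(value, int) else list(value)

print(n_estimators_candidates({"n_estimators": 200}))              # [200]
print(n_estimators_candidates({"n_estimators": [50, 500, 1000]}))  # [50, 500, 1000]
```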
requirements.txt CHANGED

@@ -11,3 +11,4 @@ lightgbm
 optuna
 joblib
 map4
+tqdm
src/lightgbm_trainer.py CHANGED

@@ -11,6 +11,7 @@ import numpy as np
 import optuna
 import pandas as pd
 from sklearn.metrics import roc_auc_score
+from tqdm import tqdm
 
 from .constants import TARGET_NAMES

@@ -23,7 +24,29 @@ class TaskTrainingOutput:
     best_params: Dict
 
 
-def 
     params = dict(base_params)
     params.update(
         {

@@ -36,6 +59,7 @@ def _sample_hyperparams(trial: optuna.Trial, base_params: Dict) -> Dict:
             "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
             "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
             "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
         }
     )
     params.setdefault("objective", "binary")

@@ -52,7 +76,7 @@ def train_lightgbm_task(
     X_val: np.ndarray,
     y_val: np.ndarray,
     base_params: Dict,
-
     early_stopping_rounds: int,
     n_trials: int,
     seed: int,

@@ -61,8 +85,7 @@ def train_lightgbm_task(
         return None
 
     def objective(trial: optuna.Trial) -> float:
-        params = _sample_hyperparams(trial, base_params)
-        params["n_estimators"] = boosting_rounds
         params["random_state"] = seed
         model = lgb.LGBMClassifier(**params)
         model.fit(

@@ -77,17 +100,15 @@ def train_lightgbm_task(
                     verbose=False,
                 )
             ],
-            verbose=False,
         )
-        best_iter = getattr(model, "best_iteration_", 
         preds = model.predict_proba(X_val, num_iteration=best_iter)[:, 1]
         return float(roc_auc_score(y_val, preds))
 
     study = optuna.create_study(direction="maximize")
     study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
 
-    best_params = _sample_hyperparams(study.best_trial, base_params)
-    best_params["n_estimators"] = boosting_rounds
     best_params["random_state"] = seed
 
     final_model = lgb.LGBMClassifier(**best_params)

@@ -103,10 +124,9 @@ def train_lightgbm_task(
                     verbose=False,
                 )
             ],
-            verbose=False,
         )
 
-    best_iteration = getattr(final_model, "best_iteration_", 
     val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
     val_auc = roc_auc_score(y_val, val_preds)

@@ -139,12 +159,13 @@ def train_stage_one_models(
     training_cfg = config.get("training", {})
     base_params = training_cfg.get("lightgbm_params", {})
     n_trials = training_cfg.get("optuna_trials", 40)
-
     early_stopping = training_cfg.get("early_stopping_rounds", 100)
     seed = config.get("seed", 42)
 
     n_train = len(train_df)
-    n_tasks = len(
 
     train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
     val_preds = (

@@ -156,72 +177,74 @@ def train_stage_one_models(
     metrics: Dict[str, Dict] = {}
     params_dump: Dict[str, Dict] = {}
 
-            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
-            continue
 
-            "best_iteration": best_iter,
-            "val_auc": task_result.val_auc,
-        }
 
-        val_features,
         num_iteration=best_iter,
     )[:, 1]
 
     save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
     params_path = checkpoint_dir / "stage1_params.json"
|
| 17 |
|
|
|
|
| 24 |
best_params: Dict
|
| 25 |
|
| 26 |
|
| 27 |
+
def resolve_n_estimators(training_cfg: Dict) -> Sequence[int]:
|
| 28 |
+
"""Normalize the n_estimators config entry into a non-empty list of ints."""
|
| 29 |
+
if "n_estimators" in training_cfg:
|
| 30 |
+
raw_value = training_cfg["n_estimators"]
|
| 31 |
+
elif "boosting_rounds" in training_cfg:
|
| 32 |
+
raw_value = training_cfg["boosting_rounds"]
|
| 33 |
+
else:
|
| 34 |
+
raw_value = [50, 500, 1000]
|
| 35 |
+
|
| 36 |
+
if isinstance(raw_value, int):
|
| 37 |
+
choices = [int(raw_value)]
|
| 38 |
+
elif isinstance(raw_value, Sequence) and not isinstance(raw_value, (str, bytes)):
|
| 39 |
+
choices = [int(v) for v in raw_value]
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError("training.n_estimators must be an int or a sequence of ints")
|
| 42 |
+
|
| 43 |
+
choices = [v for v in choices if v > 0]
|
| 44 |
+
if not choices:
|
| 45 |
+
raise ValueError("training.n_estimators must contain at least one positive value")
|
| 46 |
+
return choices
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _sample_hyperparams(trial: optuna.Trial, base_params: Dict, n_estimators_choices: Sequence[int]) -> Dict:
|
| 50 |
params = dict(base_params)
|
| 51 |
params.update(
|
| 52 |
{
|
|
|
|
| 59 |
"bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
|
| 60 |
"reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
|
| 61 |
"reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
|
| 62 |
+
"n_estimators": trial.suggest_categorical("n_estimators", list(n_estimators_choices)),
|
| 63 |
}
|
| 64 |
)
|
| 65 |
params.setdefault("objective", "binary")
|
|
|
|
| 76 |
X_val: np.ndarray,
|
| 77 |
y_val: np.ndarray,
|
| 78 |
base_params: Dict,
|
| 79 |
+
n_estimators_choices: Sequence[int],
|
| 80 |
early_stopping_rounds: int,
|
| 81 |
n_trials: int,
|
| 82 |
seed: int,
|
|
|
|
| 85 |
return None
|
| 86 |
|
| 87 |
def objective(trial: optuna.Trial) -> float:
|
| 88 |
+
params = _sample_hyperparams(trial, base_params, n_estimators_choices)
|
|
|
|
| 89 |
params["random_state"] = seed
|
| 90 |
model = lgb.LGBMClassifier(**params)
|
| 91 |
model.fit(
|
|
|
|
| 100 |
verbose=False,
|
| 101 |
)
|
| 102 |
],
|
|
|
|
| 103 |
)
|
| 104 |
+
best_iter = getattr(model, "best_iteration_", params["n_estimators"])
|
| 105 |
preds = model.predict_proba(X_val, num_iteration=best_iter)[:, 1]
|
| 106 |
return float(roc_auc_score(y_val, preds))
|
| 107 |
|
| 108 |
study = optuna.create_study(direction="maximize")
|
| 109 |
study.optimize(objective, n_trials=n_trials, show_progress_bar=False)
|
| 110 |
|
| 111 |
+
best_params = _sample_hyperparams(study.best_trial, base_params, n_estimators_choices)
|
|
|
|
| 112 |
best_params["random_state"] = seed
|
| 113 |
|
| 114 |
final_model = lgb.LGBMClassifier(**best_params)
|
|
|
|
| 124 |
verbose=False,
|
| 125 |
)
|
| 126 |
],
|
|
|
|
| 127 |
)
|
| 128 |
|
| 129 |
+
best_iteration = getattr(final_model, "best_iteration_", best_params["n_estimators"])
|
| 130 |
val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
|
| 131 |
val_auc = roc_auc_score(y_val, val_preds)
|
| 132 |
|
|
|
|
| 159 |
training_cfg = config.get("training", {})
|
| 160 |
base_params = training_cfg.get("lightgbm_params", {})
|
| 161 |
n_trials = training_cfg.get("optuna_trials", 40)
|
| 162 |
+
n_estimators_choices = resolve_n_estimators(training_cfg)
|
| 163 |
early_stopping = training_cfg.get("early_stopping_rounds", 100)
|
| 164 |
seed = config.get("seed", 42)
|
| 165 |
|
| 166 |
+
task_list = list(target_names)
|
| 167 |
n_train = len(train_df)
|
| 168 |
+
n_tasks = len(task_list)
|
| 169 |
|
| 170 |
train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
|
| 171 |
val_preds = (
|
|
|
|
| 177 |
metrics: Dict[str, Dict] = {}
|
| 178 |
params_dump: Dict[str, Dict] = {}
|
| 179 |
|
| 180 |
+
with tqdm(task_list, desc="Stage 1", unit="task") as progress_bar:
|
| 181 |
+
for task_idx, task_name in enumerate(progress_bar):
|
| 182 |
+
progress_bar.set_postfix(task=task_name)
|
| 183 |
+
train_mask = train_df[task_name].notna().values
|
| 184 |
+
if val_df is None or val_features is None:
|
| 185 |
+
metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
val_mask = val_df[task_name].notna().values
|
| 189 |
+
if train_mask.sum() < 2 or val_mask.sum() < 2:
|
| 190 |
+
metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
|
| 191 |
+
continue
|
| 192 |
+
|
| 193 |
+
X_train_task = train_features[train_mask]
|
| 194 |
+
y_train_task = train_df.loc[train_mask, task_name].astype(float).values
|
| 195 |
+
X_val_task = val_features[val_mask]
|
| 196 |
+
y_val_task = val_df.loc[val_mask, task_name].astype(float).values
|
| 197 |
+
|
| 198 |
+
if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
|
| 199 |
+
metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
|
| 200 |
+
continue
|
| 201 |
+
|
| 202 |
+
task_result = train_lightgbm_task(
|
| 203 |
+
X_train_task,
|
| 204 |
+
y_train_task,
|
| 205 |
+
X_val_task,
|
| 206 |
+
y_val_task,
|
| 207 |
+
base_params=base_params,
|
| 208 |
+
n_estimators_choices=n_estimators_choices,
|
| 209 |
+
early_stopping_rounds=early_stopping,
|
| 210 |
+
n_trials=n_trials,
|
| 211 |
+
seed=seed,
|
| 212 |
+
)
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
if task_result is None:
|
| 215 |
+
metrics[task_name] = {"status": "skipped", "reason": "training failed"}
|
| 216 |
+
continue
|
| 217 |
|
| 218 |
+
model = task_result.model
|
| 219 |
+
best_iter = task_result.best_iteration
|
| 220 |
|
| 221 |
+
model_path = stage_dir / f"{task_name}.pkl"
|
| 222 |
+
joblib.dump(model, model_path)
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
+
params_dump[task_name] = {
|
| 225 |
+
**task_result.best_params,
|
| 226 |
+
"best_iteration": best_iter,
|
| 227 |
+
"val_auc": task_result.val_auc,
|
| 228 |
+
}
|
| 229 |
|
| 230 |
+
full_train_preds = model.predict_proba(
|
| 231 |
+
train_features,
|
|
|
|
| 232 |
num_iteration=best_iter,
|
| 233 |
)[:, 1]
|
| 234 |
+
train_preds[:, task_idx] = full_train_preds.astype(np.float32)
|
| 235 |
+
|
| 236 |
+
if val_preds is not None:
|
| 237 |
+
full_val_preds = model.predict_proba(
|
| 238 |
+
val_features,
|
| 239 |
+
num_iteration=best_iter,
|
| 240 |
+
)[:, 1]
|
| 241 |
+
val_preds[:, task_idx] = full_val_preds.astype(np.float32)
|
| 242 |
+
|
| 243 |
+
metrics[task_name] = {
|
| 244 |
+
"val_auc": task_result.val_auc,
|
| 245 |
+
"n_train_samples": int(train_mask.sum()),
|
| 246 |
+
"n_val_samples": int(val_mask.sum()),
|
| 247 |
+
}
|
| 248 |
|
| 249 |
save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
|
| 250 |
params_path = checkpoint_dir / "stage1_params.json"
|
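The commit's key change is replacing the single `boosting_rounds` knob with a list of `n_estimators` candidates that Optuna selects among per trial. A minimal standalone sketch of the normalization helper added above (re-stated here with `collections.abc.Sequence` for the `isinstance` check; the trainer module imports `Sequence` from `typing` instead):

```python
from collections.abc import Sequence


def resolve_n_estimators(training_cfg: dict) -> list:
    """Normalize the n_estimators config entry into a non-empty list of ints."""
    if "n_estimators" in training_cfg:
        raw_value = training_cfg["n_estimators"]
    elif "boosting_rounds" in training_cfg:  # legacy config key is still honored
        raw_value = training_cfg["boosting_rounds"]
    else:
        raw_value = [50, 500, 1000]  # defaults used by the diff

    if isinstance(raw_value, int):
        choices = [int(raw_value)]
    elif isinstance(raw_value, Sequence) and not isinstance(raw_value, (str, bytes)):
        choices = [int(v) for v in raw_value]
    else:
        raise ValueError("training.n_estimators must be an int or a sequence of ints")

    choices = [v for v in choices if v > 0]  # drop non-positive round counts
    if not choices:
        raise ValueError("training.n_estimators must contain at least one positive value")
    return choices


print(resolve_n_estimators({"n_estimators": 300}))          # [300]
print(resolve_n_estimators({"boosting_rounds": [100, 0]}))  # [100]
print(resolve_n_estimators({}))                             # [50, 500, 1000]
```

The resulting list is handed to `trial.suggest_categorical("n_estimators", ...)`, so the round count becomes one more tuned hyperparameter instead of a fixed constant.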
src/stage_two.py
CHANGED

@@ -6,10 +6,10 @@
 import joblib
 import numpy as np
 import pandas as pd
-from …
+from tqdm import tqdm
 
 from .constants import TARGET_NAMES
-from .lightgbm_trainer import save_stage_metrics, train_lightgbm_task
+from .lightgbm_trainer import resolve_n_estimators, save_stage_metrics, train_lightgbm_task
 
 
 def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
@@ -32,75 +32,75 @@ def train_stage_two_models(
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
-    …
+    n_estimators_choices = resolve_n_estimators(training_cfg)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)
 
    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)
 
-    n_train = len(train_df)
-    n_val = len(val_df) if val_df is not None else 0
-
    metrics: Dict[str, Dict] = {}
-    …
+    task_list = list(target_names)
+
+    with tqdm(task_list, desc="Stage 2", unit="task") as progress_bar:
+        for task_idx, task_name in enumerate(progress_bar):
+            progress_bar.set_postfix(task=task_name)
+            mask = train_df[task_name].notna().values
+            if mask.sum() == 0:
+                metrics[task_name] = {"status": "skipped", "reason": "no labels"}
+                continue
+
+            augmented_train_matrix = _build_augmented_matrix(
+                train_features[mask],
+                stage1_train_preds[mask],
+                task_idx,
+            )
+            y_train = train_df.loc[mask, task_name].astype(float).values
+
+            if (
+                val_features is None
+                or val_df is None
+                or stage1_val_preds is None
+                or val_df[task_name].notna().sum() < 2
+            ):
+                metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
+                continue
+
+            val_mask = val_df[task_name].notna().values
+            augmented_val_matrix = _build_augmented_matrix(
+                val_features[val_mask],
+                stage1_val_preds[val_mask],
+                task_idx,
+            )
+            y_val = val_df.loc[val_mask, task_name].astype(float).values
+
+            if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
+                metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
+                continue
+
+            task_result = train_lightgbm_task(
+                augmented_train_matrix,
+                y_train,
+                augmented_val_matrix,
+                y_val,
+                base_params=base_params,
+                n_estimators_choices=n_estimators_choices,
+                early_stopping_rounds=early_stopping,
+                n_trials=n_trials,
+                seed=seed,
+            )
+
+            if task_result is None:
+                metrics[task_name] = {"status": "skipped", "reason": "training failed"}
+                continue
+
+            model_path = stage_dir / f"{task_name}.pkl"
+            joblib.dump(task_result.model, model_path)
+
+            metrics[task_name] = {
+                "val_auc": task_result.val_auc,
+                "best_iteration": int(task_result.best_iteration),
+            }
 
    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}
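Stage two trains each task on the fingerprint matrix augmented with the stage-1 predictions of the sibling tasks. The body of `_build_augmented_matrix` is not visible in this diff; the sketch below is a plausible reconstruction inferred from its signature and call sites (dropping the target task's own column via `np.delete` is an assumption):

```python
import numpy as np


def build_augmented_matrix(base_features: np.ndarray,
                           prediction_matrix: np.ndarray,
                           target_idx: int) -> np.ndarray:
    """Concatenate base features with the stage-1 predictions of the OTHER tasks."""
    # Remove the target task's own stage-1 column so it cannot leak its label.
    other_preds = np.delete(prediction_matrix, target_idx, axis=1)
    return np.hstack([base_features, other_preds])


# 5 molecules, 2048-bit fingerprints, 12 Tox21 tasks of stage-1 predictions.
fp = np.zeros((5, 2048), dtype=np.float32)
stage1 = np.full((5, 12), 0.5, dtype=np.float32)
aug = build_augmented_matrix(fp, stage1, target_idx=3)
print(aug.shape)  # (5, 2059): 2048 fingerprint bits + 11 sibling predictions
```

This is how label correlations enter the second stage without training a joint multitask model: each per-task booster simply sees the others' probabilities as extra input columns.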