---
title: MultiTaskTox
emoji: π
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---
# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21
MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
## Why MultiTaskTox?
- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
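The stage-two augmentation idea can be sketched in a few lines. This is a minimal illustration with hypothetical array names and toy shapes, not the project's actual code:

```python
import numpy as np

# Toy shapes: 4 molecules, 8 fingerprint bits, 12 Tox21 tasks.
rng = np.random.default_rng(0)
fingerprints = rng.random((4, 8))    # stage-one input features
stage1_preds = rng.random((4, 12))   # per-task probabilities from stage one

def augment_for_task(task_idx: int) -> np.ndarray:
    """Fingerprint + stage-one predictions of every *other* task (leave-one-out)."""
    other_tasks = np.delete(stage1_preds, task_idx, axis=1)
    return np.hstack([fingerprints, other_tasks])

X_task0 = augment_for_task(0)
print(X_task0.shape)  # (4, 19) = 8 fingerprint bits + 11 other-task predictions
```

Excluding the task's own prediction keeps the stage-two booster from trivially copying its stage-one output while still exposing the label correlations.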
### What `train.py` does
1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
- Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
- Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
4. Generates prediction matrices for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
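The per-target loop above can be sketched as follows. This is a simplified stand-in: scikit-learn's `GradientBoostingClassifier` substitutes for LightGBM, a plain loop over candidate `n_estimators` substitutes for the Optuna search, and all data and variable names are hypothetical:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score

# Toy train/validation data standing in for fingerprints of one Tox21 target.
rng = np.random.default_rng(42)
X_train, y_train = rng.random((80, 16)), rng.integers(0, 2, 80)
X_val, y_val = rng.random((40, 16)), rng.integers(0, 2, 40)

best_auc, best_model = -1.0, None
for n_estimators in [10, 50, 100]:  # toy candidates; the config defaults to [50, 500, 1000]
    model = GradientBoostingClassifier(n_estimators=n_estimators, random_state=42)
    model.fit(X_train, y_train)
    auc = roc_auc_score(y_val, model.predict_proba(X_val)[:, 1])
    if auc > best_auc:
        best_auc, best_model = auc, model

# train.py would then pickle best_model to checkpoints/stage1/<target>.pkl.
```

The validation split serves as the selection set, mirroring how the real workflow hands it to Optuna as the evaluation set.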
## Inference
`predict.py` exposes:
```python
from predict import predict
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```
The function:
1. Loads the training manifest to know which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
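The returned structure can be consumed like a plain nested dict. The values below are illustrative only (`NR-AR` and `SR-p53` are two of the twelve Tox21 endpoints):

```python
# Hypothetical output for two inputs; probabilities are made up.
results = {
    "CCO": {"NR-AR": 0.12, "SR-p53": 0.84},
    "not_a_smiles": {"NR-AR": 0.5, "SR-p53": 0.5},  # invalid SMILES fall back to 0.5
}

# Flag molecules predicted toxic on any endpoint at a 0.5 threshold.
flagged = [smi for smi, preds in results.items()
           if any(p > 0.5 for p in preds.values())]
print(flagged)  # ['CCO']
```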
## Configuration Overview (`config/config.json`)
```json
{
"seed": 42,
"dataset": {"name": "ml-jku/tox21"},
"features": {
"type": "ecfp",
"radius": 2,
"n_bits": 1024,
"map4_dim": 1024,
"cache_dir": "./checkpoints/cache"
},
"training": {
"optuna_trials": 40,
"n_estimators": [50, 500, 1000],
"early_stopping_rounds": 100,
"lightgbm_params": {
"objective": "binary",
"metric": "auc",
"verbosity": -1
}
},
"multitask": {"enabled": true},
"output": {"checkpoint_dir": "./checkpoints"}
}
```
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
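The int-or-list convention for `training.n_estimators` could be normalized with a small helper like this (a sketch; the helper name is hypothetical):

```python
import json

config = json.loads("""
{"training": {"optuna_trials": 40, "n_estimators": [50, 500, 1000]}}
""")

def n_estimators_candidates(cfg: dict) -> list:
    """Return the Optuna search space for n_estimators as a list."""
    value = cfg["training"]["n_estimators"]
    return [value] if isinstance(value, int) else list(value)

print(n_estimators_candidates(config))                               # [50, 500, 1000]
print(n_estimators_candidates({"training": {"n_estimators": 200}}))  # [200]
```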
## Repository Layout
- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.