Maximilian Schuh
---
title: MultiTaskTox
emoji: 📈
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---
# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21
MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
## Why MultiTaskTox?
- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
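The multitask enhancement can be pictured with plain Python lists. The following is a minimal sketch, not the code in `src/stage_two.py`. The function name `augment_for_task` is hypothetical. The sketch assumes that the rows of the fingerprint matrix and the stage-one prediction matrix are aligned. It also assumes that the target task's own stage-one prediction is left out, following the "predictions of the other tasks" wording above.

```python
from typing import List, Sequence


def augment_for_task(
    fingerprints: Sequence[Sequence[float]],
    stage1_preds: Sequence[Sequence[float]],
    task_index: int,
) -> List[List[float]]:
    """Append the stage-one predictions of every *other* task to each fingerprint row.

    fingerprints: n_samples x n_bits
    stage1_preds: n_samples x n_tasks (stage-one probabilities)
    task_index:   column of the task being modelled; its own prediction is excluded
                  (an assumption of this sketch)
    """
    if len(fingerprints) != len(stage1_preds):
        raise ValueError("fingerprints and stage1_preds must have the same number of rows")
    augmented = []
    for fp_row, pred_row in zip(fingerprints, stage1_preds):
        others = [p for j, p in enumerate(pred_row) if j != task_index]
        augmented.append(list(fp_row) + others)
    return augmented


# Toy example: 2 molecules, 4-bit fingerprints, 3 tasks.
fps = [[1, 0, 1, 0], [0, 1, 0, 0]]
preds = [[0.9, 0.2, 0.4], [0.1, 0.7, 0.3]]
X_task1 = augment_for_task(fps, preds, task_index=1)
print(X_task1)  # [[1, 0, 1, 0, 0.9, 0.4], [0, 1, 0, 0, 0.1, 0.3]]
```

Under these assumptions, each stage-two booster sees `n_bits + (n_tasks - 1)` features. With the 12 Tox21 targets and 1024-bit ECFP, that is 1035 columns.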
## What `train.py` does
1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
   - Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
   - Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
4. Generates prediction matrices for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
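The per-target scores in `metrics_stage1.json` and `metrics_stage2.json` are presumably ROC AUC values, since the config sets LightGBM's `metric` to `"auc"`. The sketch below computes ROC AUC in pure Python using the rank-based Mann-Whitney formulation. Several details are assumptions made for illustration:

- the helper names;
- the flat `{target: auc}` layout;
- the use of `None` for a missing Tox21 label.

For labelled samples, `sklearn.metrics.roc_auc_score` returns the same value.

```python
import json
from typing import Dict, List, Optional, Sequence


def roc_auc(labels: Sequence[Optional[int]], scores: Sequence[float]) -> Optional[float]:
    """ROC AUC via the Mann-Whitney U statistic; samples whose label is None are skipped."""
    pairs = [(s, y) for s, y in zip(scores, labels) if y is not None]
    n_pos = sum(1 for _, y in pairs if y == 1)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        return None  # AUC is undefined when only one class is present
    # Assign average 1-based ranks so tied scores count as half-correct.
    pairs.sort(key=lambda p: p[0])
    ranks = [0.0] * len(pairs)
    i = 0
    while i < len(pairs):
        j = i
        while j + 1 < len(pairs) and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        for k in range(i, j + 1):
            ranks[k] = (i + j) / 2 + 1
        i = j + 1
    pos_rank_sum = sum(r for r, (_, y) in zip(ranks, pairs) if y == 1)
    u = pos_rank_sum - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


def per_target_metrics(
    y_true: Dict[str, List[Optional[int]]], y_score: Dict[str, List[float]]
) -> Dict[str, Optional[float]]:
    """Hypothetical {target: auc} mapping, as it might be serialized to a metrics file."""
    return {t: roc_auc(y_true[t], y_score[t]) for t in y_true}


labels = {"NR-AR": [1, 0, None, 0], "SR-p53": [0, 1, 1, 0]}
scores = {"NR-AR": [0.8, 0.3, 0.5, 0.4], "SR-p53": [0.2, 0.6, 0.9, 0.6]}
metrics = per_target_metrics(labels, scores)
print(json.dumps(metrics))  # {"NR-AR": 1.0, "SR-p53": 0.875}
```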
## Inference
`predict.py` exposes:
```python
from predict import predict
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)  # {smiles: {target_name: probability}}
```
The function:
1. Loads the training manifest to know which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
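Step 5 can be sketched as follows. `assemble_results` is a hypothetical helper, not the actual code in `predict.py`. The sketch makes two assumptions:

- the stage-one or stage-two probabilities have already been computed;
- `None` marks a SMILES that failed standardization.

Only two Tox21 task names are used, to keep the example short.

```python
from typing import Dict, Optional, Sequence

FALLBACK_PROBABILITY = 0.5  # value the README specifies for invalid SMILES


def assemble_results(
    smiles_list: Sequence[str],
    probabilities: Sequence[Optional[Sequence[float]]],
    target_names: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """Build {smiles: {target_name: probability}}.

    probabilities[i] is the per-target vector for smiles_list[i], or None if that
    SMILES could not be standardized or fingerprinted.
    """
    results: Dict[str, Dict[str, float]] = {}
    for smi, probs in zip(smiles_list, probabilities):
        if probs is None:
            results[smi] = {t: FALLBACK_PROBABILITY for t in target_names}
        else:
            # Clip defensively so every value lies in [0, 1].
            results[smi] = {t: min(1.0, max(0.0, float(p))) for t, p in zip(target_names, probs)}
    return results


targets = ["NR-AR", "SR-p53"]
out = assemble_results(["CCO", "not_a_smiles"], [[0.12, 0.85], None], targets)
print(out)
# {'CCO': {'NR-AR': 0.12, 'SR-p53': 0.85}, 'not_a_smiles': {'NR-AR': 0.5, 'SR-p53': 0.5}}
```

Because the return value is keyed by SMILES, duplicate input strings collapse into a single entry in this sketch. In that case the last occurrence wins.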
## Configuration Overview (`config/config.json`)
```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (the MAP4 dependency is installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
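A small loader can sanity-check these settings before training. This is a sketch under two assumptions:

- `fingerprint_dim` and `n_estimators_choices` are hypothetical helpers;
- the ECFP length is read from `n_bits` and the MAP4 length from `map4_dim`. This mapping is inferred from the key names, not confirmed by `train.py`.

```python
import json
from typing import Any, Dict, List, Union


def fingerprint_dim(features: Dict[str, Any]) -> int:
    """Return the fingerprint length implied by the `features` section (assumed key mapping)."""
    ftype = features["type"]
    if ftype == "ecfp":
        return int(features["n_bits"])  # assumption: ECFP length comes from n_bits
    if ftype == "map4":
        return int(features["map4_dim"])  # assumption: MAP4 length comes from map4_dim
    raise ValueError(f"unknown features.type: {ftype!r}")


def n_estimators_choices(value: Union[int, List[int]]) -> List[int]:
    """Normalize training.n_estimators into a list of candidate values."""
    if isinstance(value, bool):  # bool is a subclass of int, so reject it explicitly
        raise TypeError("n_estimators must be an int or a list of ints")
    if isinstance(value, int):
        return [value]  # a single integer means a one-element search space
    choices = [int(v) for v in value]
    if not choices:
        raise ValueError("n_estimators list must not be empty")
    return choices


cfg = json.loads(
    '{"features": {"type": "map4", "n_bits": 1024, "map4_dim": 2048},'
    ' "training": {"n_estimators": 500}, "multitask": {"enabled": false}}'
)
print(fingerprint_dim(cfg["features"]))                       # 2048
print(n_estimators_choices(cfg["training"]["n_estimators"]))  # [500]
print(cfg["multitask"]["enabled"])                            # False
```

The normalized list could then be handed to Optuna as `trial.suggest_categorical("n_estimators", choices)`. A single integer then effectively fixes the value, while a list lets the search pick among the candidates.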
## Repository Layout
- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
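`src/features.py` is described as computing fingerprints with disk caching under `features.cache_dir`. One way such a cache could work is sketched below; it is not the module's actual implementation. `cache_key`, `cached_fingerprints`, and the JSON file format are hypothetical. `fake_compute` stands in for the real RDKit/MAP4 featurizer.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Sequence


def cache_key(smiles: Sequence[str], feature_cfg: Dict) -> str:
    """Hash the SMILES together with the fingerprint settings (cache_dir itself excluded)."""
    settings = {k: v for k, v in feature_cfg.items() if k != "cache_dir"}
    payload = json.dumps({"smiles": list(smiles), "features": settings}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def cached_fingerprints(
    smiles: Sequence[str],
    feature_cfg: Dict,
    cache_dir: Path,
    compute: Callable[[Sequence[str], Dict], List[List[int]]],
) -> List[List[int]]:
    """Return fingerprints from disk if present, otherwise compute and store them."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    path = cache_dir / f"{cache_key(smiles, feature_cfg)}.json"
    if path.exists():
        return json.loads(path.read_text())
    fps = compute(smiles, feature_cfg)
    path.write_text(json.dumps(fps))
    return fps


calls = []


def fake_compute(smiles: Sequence[str], cfg: Dict) -> List[List[int]]:
    # Stand-in featurizer: one dummy bit per molecule (parity of the SMILES length).
    calls.append(len(smiles))
    return [[len(s) % 2] for s in smiles]


with tempfile.TemporaryDirectory() as tmp:
    cfg = {"type": "ecfp", "radius": 2, "n_bits": 1024}
    first = cached_fingerprints(["CCO", "CC"], cfg, Path(tmp), fake_compute)
    second = cached_fingerprints(["CCO", "CC"], cfg, Path(tmp), fake_compute)
    print(first == second, len(calls))  # True 1
```

Hashing the fingerprint settings together with the SMILES means that a change to `radius`, `n_bits`, or `type` produces a new cache file. Stale features are never silently reused.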