---
title: MultiTaskTox
emoji: π
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---
# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21

MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
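The fingerprint step maps well onto RDKit's standard Morgan-fingerprint API. The sketch below mirrors the default configuration (radius 2, 1024 bits) but is only an illustration; the repository's actual preprocessing lives in `src/preprocess.py` and `src/features.py` and may differ in detail:

```python
from rdkit import Chem
from rdkit.Chem import AllChem

smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, as an example input
mol = Chem.MolFromSmiles(smiles)  # returns None for invalid SMILES

# ECFP-style Morgan fingerprint matching the default config:
# radius 2 (ECFP4-equivalent), 1024-bit vector.
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024)
on_bits = list(fp.GetOnBits())  # indices of the set bits
```

Because `Chem.MolFromSmiles` also canonicalizes its input, running the same call at train and inference time is what keeps the two phases consistent.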
## Why MultiTaskTox?

- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
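The multitask augmentation can be sketched with plain NumPy. The helper below is illustrative (the repository's actual logic lives in `src/stage_two.py` and its names are not the project's API): for task *t*, the stage-two input is the fingerprint matrix with the other tasks' stage-one probabilities appended, dropping task *t*'s own column so the model never sees its own prediction.

```python
import numpy as np

def augment_for_task(X, stage1_preds, task_idx):
    """Append the other tasks' stage-one probabilities to the fingerprints.

    X            : (n_samples, n_bits) fingerprint matrix
    stage1_preds : (n_samples, n_tasks) stage-one probabilities
    task_idx     : task being trained; its own column is removed
    """
    other = np.delete(stage1_preds, task_idx, axis=1)
    return np.hstack([X, other])

rng = np.random.default_rng(0)
X = rng.random((4, 1024))       # dummy fingerprints
P = rng.random((4, 12))         # dummy stage-one probabilities, 12 Tox21 tasks
X_aug = augment_for_task(X, P, task_idx=0)
print(X_aug.shape)              # (4, 1035) = 1024 bits + 11 other-task columns
```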
### What `train.py` does

1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
   - Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
   - Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
4. Generates prediction matrices for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
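The manifest in step 6 is what later lets `predict.py` rebuild the exact feature pipeline. A minimal sketch of how it might be serialized with the standard library is shown below; the field names here are assumptions for illustration, not the schema `train.py` actually writes:

```python
import json

# Hypothetical manifest fields -- the real schema is defined by train.py.
manifest = {
    "features": {"type": "ecfp", "radius": 2, "n_bits": 1024},
    "stage1_dir": "checkpoints/stage1",
    "stage2_dir": "checkpoints/stage2",
    "multitask_enabled": True,
    # Illustrative subset of endpoint names; Tox21 defines 12 in total.
    "targets": ["NR-AR", "NR-AR-LBD", "NR-AhR"],
}

# Round-trip through JSON, as predict.py would when loading the manifest.
serialized = json.dumps(manifest, indent=2, sort_keys=True)
restored = json.loads(serialized)
```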
## Inference

`predict.py` exposes:

```python
from predict import predict

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```

The function:

1. Loads the training manifest to know which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
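The output shape in step 5, including the invalid-SMILES fallback, can be sketched in plain Python. `assemble_results` is an illustrative helper, not a function the repository exports:

```python
# Illustrative subset of endpoint names; the real model covers all 12.
TARGETS = ["NR-AR", "NR-AhR"]

def assemble_results(smiles_list, probs, valid_mask):
    """Shape the leaderboard output: {smiles: {target: probability}}.

    probs      : one row of per-target probabilities per molecule
    valid_mask : False where standardization failed -> fall back to 0.5
    """
    results = {}
    for smi, row, ok in zip(smiles_list, probs, valid_mask):
        if ok:
            results[smi] = {t: float(p) for t, p in zip(TARGETS, row)}
        else:
            results[smi] = {t: 0.5 for t in TARGETS}  # invalid-SMILES fallback
    return results

out = assemble_results(
    ["CCO", "not_a_smiles"],
    [[0.92, 0.07], [0.0, 0.0]],
    [True, False],
)
```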
## Configuration Overview (`config/config.json`)

```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
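The int-or-list convention for `training.n_estimators` amounts to normalizing the config value into a list of candidates before handing it to the tuner. A minimal sketch (the helper name is ours, not the repository's):

```python
def n_estimators_candidates(cfg_value):
    """Normalize the config value into a candidate list for Optuna.

    A single int means a fixed value (nothing to search);
    a list means Optuna picks among the given candidates.
    """
    if isinstance(cfg_value, int):
        return [cfg_value]
    return list(cfg_value)

print(n_estimators_candidates(500))              # [500]
print(n_estimators_candidates([50, 500, 1000]))  # [50, 500, 1000]
```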
## Repository Layout

- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.