---
title: MultiTaskTox
emoji: π
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---
# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21
MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the Tox21 benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
## Why MultiTaskTox?
- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
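The preprocessing guarantee can be sketched as follows. This is a minimal illustration with RDKit, not the project's actual code (the real helpers live in `src/preprocess.py` and `src/features.py`, and the function name here is hypothetical):

```python
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def smiles_to_ecfp(smiles: str, radius: int = 2, n_bits: int = 1024):
    """Parse and canonicalize a SMILES, then return its binary ECFP.

    Returns None for invalid SMILES. Values match the config defaults
    (radius 2, 1024 bits); standardization here is reduced to RDKit's
    parse-time sanitization plus canonicalization.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    # Round-trip through canonical SMILES so training and inference
    # always fingerprint the same normalized structure.
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
```

Because the same function runs at train and predict time, a molecule always maps to the same bit vector regardless of how its input SMILES was written.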
## What `train.py` does
- Loads the predefined `train` and `validation` splits from the Tox21 dataset.
- Standardizes SMILES and builds fingerprints using `src/features.py`.
- For each target:
  - Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
  - Fits the classifier (stage 1) and stores the model as `checkpoints/stage1/<target>.pkl`.
  - Generates prediction matrices for both splits.
- If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
- Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
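The stage-two feature augmentation can be sketched with plain NumPy. The array names and helper below are illustrative only; the real logic lives in `src/stage_two.py`:

```python
import numpy as np

def augment_for_task(fingerprints, stage1_preds, task_idx):
    """Append every *other* task's stage-one probabilities to the fingerprints.

    fingerprints: (n_samples, n_bits) feature matrix
    stage1_preds: (n_samples, n_tasks) stage-one probability matrix
    task_idx:     the task whose stage-two booster will be trained
    """
    # Drop the task's own column so its booster never sees its own prediction.
    other = np.delete(stage1_preds, task_idx, axis=1)
    return np.hstack([fingerprints, other])

X = np.zeros((4, 1024))    # e.g. 4 molecules, 1024-bit fingerprints
P = np.random.rand(4, 12)  # stage-one probabilities for the 12 Tox21 tasks
X2 = augment_for_task(X, P, task_idx=0)
print(X2.shape)  # (4, 1035): 1024 fingerprint bits + 11 cross-task predictions
```

Each endpoint therefore trains on its original fingerprint plus eleven extra features, which is how label correlations are captured without a fully joint model.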
## Inference
`predict.py` exposes:

```python
from predict import predict

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```
The function:

- Loads the training manifest to know which fingerprint type and checkpoints to use.
- Standardizes and fingerprints the SMILES on the fly.
- Runs the stage-one LightGBM classifiers to obtain probabilistic predictions.
- If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
- Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
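The output contract can be illustrated with a small shape-only sketch (dummy probabilities and a hypothetical `format_results` helper; the real function loads the trained checkpoints):

```python
# Abbreviated for the sketch; the Tox21 benchmark defines 12 targets.
TARGETS = ["NR-AR", "NR-AhR", "SR-p53"]

def format_results(smiles_list, prob_rows):
    """Map each SMILES to {target: probability}; a None row marks invalid SMILES."""
    results = {}
    for smi, row in zip(smiles_list, prob_rows):
        if row is None:
            # Invalid SMILES: neutral 0.5 fallback for every target.
            results[smi] = {t: 0.5 for t in TARGETS}
        else:
            results[smi] = dict(zip(TARGETS, row))
    return results

out = format_results(["CCO", "???"], [[0.1, 0.9, 0.2], None])
print(out["???"]["NR-AR"])  # 0.5
```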
## Configuration Overview (`config/config.json`)
```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
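One way the int-or-list convention could be normalized before building the Optuna search space (a hypothetical helper; the actual parsing lives in `train.py` / `src/lightgbm_trainer.py`):

```python
def n_estimators_candidates(value):
    """Normalize config 'n_estimators' into a list of candidate values.

    A single int becomes a one-element search space (effectively fixed);
    a list of ints is used as the categorical candidates directly.
    """
    if isinstance(value, int):
        return [value]
    if isinstance(value, list) and all(isinstance(v, int) for v in value):
        return value
    raise ValueError("n_estimators must be an int or a list of ints")

print(n_estimators_candidates(300))              # [300]
print(n_estimators_candidates([50, 500, 1000]))  # [50, 500, 1000]
```

The resulting list would then feed something like Optuna's categorical suggestion for the `n_estimators` hyperparameter.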
## Repository Layout
- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.