---
title: MultiTaskTox
emoji: 📈
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---

# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21

MultiTaskTox is a two-stage gradient-boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.

## Why MultiTaskTox?

- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.

### What `train.py` does

1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
   - Runs Optuna to find the best LightGBM hyperparameters, using the validation split as the evaluation set.
   - Fits the classifier (`stage1`) and stores the model as a pickle (`.pkl`) under `checkpoints/stage1/`.
4. Generates prediction matrices for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.

## Inference

`predict.py` exposes:

```python
from predict import predict

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```

The function:

1. Loads the training manifest to determine which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs the stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.

## Configuration Overview (`config/config.json`)

```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```

- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.

## Repository Layout

- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
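## Stage-Two Augmentation at a Glance

The multitask enhancement described above boils down to one feature transform: for each target, the stage-two booster sees the fingerprint concatenated with the stage-one probabilities of the *other* 11 tasks. The sketch below illustrates that transform with NumPy; `augment_features` is a hypothetical helper for illustration, not the actual API in `src/stage_two.py`.

```python
import numpy as np

def augment_features(fingerprints, stage1_preds, task_idx):
    """Build stage-two inputs for one task.

    fingerprints : (n_samples, n_bits) binary fingerprint matrix
    stage1_preds : (n_samples, n_tasks) stage-one probabilities
    task_idx     : index of the task being trained in stage two
    """
    # Drop the task's own column so it never sees its own stage-one prediction.
    other_tasks = np.delete(stage1_preds, task_idx, axis=1)
    # Concatenate fingerprint bits with the remaining tasks' probabilities.
    return np.hstack([fingerprints, other_tasks])

# Toy example: 4 molecules, 8-bit fingerprints, 12 Tox21 tasks.
rng = np.random.default_rng(42)
fps = rng.integers(0, 2, size=(4, 8))
preds = rng.random((4, 12))

aug = augment_features(fps, preds, task_idx=0)
print(aug.shape)  # (4, 19): 8 fingerprint bits + 11 other-task probabilities
```

The augmented matrix is what the stage-two LightGBM classifier for that target is trained on; excluding the task's own column avoids trivially leaking its stage-one output back into its stage-two input.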