---
title: MultiTaskTox
emoji: πŸ“ˆ
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---

# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21

MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the Tox21 benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.

## Why MultiTaskTox?

- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
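
The multitask enhancement described above can be sketched with NumPy. `augment_for_task` is a hypothetical helper for illustration, not the project's actual API:

```python
import numpy as np

def augment_for_task(fingerprints, stage1_preds, task_idx):
    """Concatenate each molecule's fingerprint with the stage-one
    probabilities of every task except task_idx."""
    other_preds = np.delete(stage1_preds, task_idx, axis=1)
    return np.hstack([fingerprints, other_preds])

fps = np.zeros((4, 1024))       # 4 molecules, 1024-bit fingerprints
stage1 = np.full((4, 12), 0.5)  # stage-one probabilities, 12 Tox21 tasks
augmented = augment_for_task(fps, stage1, task_idx=3)
print(augmented.shape)          # (4, 1035): 1024 bits + 11 other-task probabilities
```

Each task's stage-two booster thus sees the fingerprint plus 11 extra columns, which is how label correlations enter without a fully joint model.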

## What `train.py` does

1. Loads the predefined train and validation splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
   - Runs Optuna to find the best LightGBM hyperparameters, using the validation split as the evaluation set.
   - Fits the stage-one classifier and stores the model as `checkpoints/stage1/<target>.pkl`.
4. Generates prediction matrices for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
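
The checkpoint and manifest layout from steps 3 and 6 can be sketched with the standard library alone; the helper names below are illustrative, only the on-disk paths follow the steps above:

```python
import json
import pickle
import tempfile
from pathlib import Path

def save_stage1_model(model, target, checkpoint_dir="checkpoints"):
    """Pickle one per-target booster under checkpoints/stage1/<target>.pkl."""
    path = Path(checkpoint_dir) / "stage1" / f"{target}.pkl"
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(model, f)
    return path

def write_manifest(manifest, checkpoint_dir="checkpoints"):
    """Record the experiment description that predict.py reads back."""
    path = Path(checkpoint_dir) / "training_manifest.json"
    with path.open("w") as f:
        json.dump(manifest, f, indent=2)
    return path

# Demo with a stand-in "model" (any picklable object) in a temp directory.
tmp = tempfile.mkdtemp()
saved = save_stage1_model({"params": {"metric": "auc"}}, "NR-AR", checkpoint_dir=tmp)
manifest_path = write_manifest({"fingerprint": "ecfp"}, checkpoint_dir=tmp)
print(saved.name)  # NR-AR.pkl
```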

## Inference

`predict.py` exposes:

```python
from predict import predict

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```

The function:

1. Loads the training manifest to know which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
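
The return shape from step 5, including the `0.5` fallback, can be illustrated in pure Python. `format_results` and the truncated target list are hypothetical, not the module's real internals:

```python
# Four of the 12 Tox21 targets, for illustration only.
TARGETS = ["NR-AR", "NR-AR-LBD", "NR-AhR", "SR-ARE"]

def format_results(smiles_list, prob_rows, valid_mask, targets=TARGETS):
    """Build {smiles: {target: probability}}; invalid SMILES get 0.5 everywhere."""
    results = {}
    for smi, row, ok in zip(smiles_list, prob_rows, valid_mask):
        if ok:
            results[smi] = dict(zip(targets, row))
        else:
            results[smi] = {t: 0.5 for t in targets}
    return results

out = format_results(
    ["CCO", "not_a_smiles"],
    [[0.12, 0.03, 0.40, 0.08], [0.0, 0.0, 0.0, 0.0]],
    [True, False],
)
print(out["not_a_smiles"]["NR-AhR"])  # 0.5
```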

## Configuration Overview (`config/config.json`)

```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
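
A minimal sketch of how the int-or-list convention for `training.n_estimators` could be normalized into an Optuna candidate list; the helper name is hypothetical, the actual parsing lives in the training code:

```python
def n_estimators_candidates(cfg_value):
    """Accept the int-or-list config convention and always return
    a list of candidate values for the Optuna search space."""
    if isinstance(cfg_value, int):
        return [cfg_value]
    return list(cfg_value)

print(n_estimators_candidates(500))              # [500]
print(n_estimators_candidates([50, 500, 1000]))  # [50, 500, 1000]
```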

## Repository Layout

- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.