---
title: MultiTaskTox
emoji: 📈
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
short_description: MultiTaskTox is a two-stage Gradient Boosting workflow
---

# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21

MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.

## Why MultiTaskTox?

- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, so training and inference apply exactly the same transformation.
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
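The multitask enhancement amounts to a column operation on two matrices. The sketch below assumes the stage-one predictions are held in a NumPy array of shape `(n_molecules, n_tasks)`; the helper `augment_features` is hypothetical and is not taken from `src/stage_two.py`.

```python
import numpy as np


def augment_features(fingerprints: np.ndarray, stage1_preds: np.ndarray, task_idx: int) -> np.ndarray:
    """Append the stage-one predictions of every task except `task_idx` (hypothetical helper).

    fingerprints: (n_molecules, n_bits) fingerprint matrix.
    stage1_preds: (n_molecules, n_tasks) stage-one probabilities.
    Returns an (n_molecules, n_bits + n_tasks - 1) matrix for the stage-two model of `task_idx`.
    """
    if fingerprints.shape[0] != stage1_preds.shape[0]:
        raise ValueError("fingerprints and stage1_preds must have the same number of rows")
    # Drop the target's own column so the stage-two booster only receives cross-task signal.
    other_tasks = np.delete(stage1_preds, task_idx, axis=1)
    return np.hstack([fingerprints, other_tasks])


# Toy example: 3 molecules, 4-bit fingerprints, 12 tasks.
rng = np.random.default_rng(42)
X = rng.integers(0, 2, size=(3, 4)).astype(np.float32)
P = rng.random((3, 12))
X_aug = augment_features(X, P, task_idx=0)
print(X_aug.shape)  # (3, 15)
```

Excluding the target's own column matches the "predictions of the other tasks" wording: the stage-two model for a target sees its fingerprint plus the 11 remaining stage-one outputs.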

### What `train.py` does

1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
3. For each target:
   - Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
   - Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
4. Generates stage-one prediction matrices (one column per target) for both splits.
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
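The per-target Optuna step in item 3 can be pictured as a function that turns a trial into a LightGBM parameter dictionary. This is a minimal sketch, not the code in `src/lightgbm_trainer.py`: the function name `suggest_lgbm_params`, the chosen hyperparameters and their ranges are assumptions. It relies only on the `suggest_int`, `suggest_float` and `suggest_categorical` methods of an Optuna trial, and an offline stand-in trial is included so the example runs without Optuna installed.

```python
from typing import Any, Dict, List, Protocol, Sequence


class TrialLike(Protocol):
    """The subset of optuna.trial.Trial methods used below."""

    def suggest_int(self, name: str, low: int, high: int) -> int: ...

    def suggest_float(self, name: str, low: float, high: float, *, log: bool = False) -> float: ...

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any: ...


def suggest_lgbm_params(
    trial: TrialLike, base_params: Dict[str, Any], n_estimators: List[int]
) -> Dict[str, Any]:
    """Sample one LightGBM configuration for a single Tox21 target (hypothetical helper).

    The search ranges are illustrative assumptions, not the ones used by train.py.
    """
    params = dict(base_params)  # copy so the shared config dict is never mutated
    params.update(
        n_estimators=trial.suggest_categorical("n_estimators", n_estimators),
        learning_rate=trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
        num_leaves=trial.suggest_int("num_leaves", 15, 255),
        min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
        colsample_bytree=trial.suggest_float("colsample_bytree", 0.2, 1.0),
    )
    return params


class LowestValueTrial:
    """Offline stand-in for an Optuna trial: always picks the first/lowest option."""

    def suggest_int(self, name, low, high):
        return low

    def suggest_float(self, name, low, high, *, log=False):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


base = {"objective": "binary", "metric": "auc", "verbosity": -1}  # config["training"]["lightgbm_params"]
params = suggest_lgbm_params(LowestValueTrial(), base, [50, 500, 1000])
print(params["n_estimators"], params["num_leaves"])  # 50 15
```

Inside a real Optuna objective, the actual trial object would be passed in, the dictionary used to build `lightgbm.LGBMClassifier(**params)` on the train split, and the validation AUC returned as the value to maximize.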

## Inference

`predict.py` exposes:

```python
from predict import predict

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles)
```

The function:
1. Loads the training manifest to determine which fingerprint type and checkpoints to use.
2. Standardizes and fingerprints the SMILES on the fly.
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. For an invalid SMILES string, every target falls back to `0.5`.
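The return contract in step 5 can be illustrated with a small assembly function. This is a sketch under assumptions: `assemble_results` is a hypothetical name, and it assumes upstream code marks an unprocessable SMILES string with `None` instead of a probability row; it is not the implementation in `predict.py`.

```python
from typing import Dict, Optional, Sequence

# Fallback value documented above for SMILES strings that cannot be processed.
FALLBACK_PROBABILITY = 0.5


def assemble_results(
    smiles_list: Sequence[str],
    probabilities: Sequence[Optional[Sequence[float]]],
    targets: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """Build a {smiles: {target_name: probability}} mapping (hypothetical helper).

    probabilities[i] is None when smiles_list[i] could not be standardized or fingerprinted;
    in that case every target receives FALLBACK_PROBABILITY.
    """
    if len(smiles_list) != len(probabilities):
        raise ValueError("expected one probability row (or None) per SMILES string")
    results: Dict[str, Dict[str, float]] = {}
    for smi, row in zip(smiles_list, probabilities):
        if row is None:
            results[smi] = {target: FALLBACK_PROBABILITY for target in targets}
        else:
            if len(row) != len(targets):
                raise ValueError("each probability row needs one value per target")
            results[smi] = {target: float(p) for target, p in zip(targets, row)}
    return results


targets = ["NR-AR", "SR-p53"]  # two of the 12 Tox21 targets, for brevity
out = assemble_results(["CCO", "not_a_smiles"], [[0.1, 0.7], None], targets)
print(out)
# {'CCO': {'NR-AR': 0.1, 'SR-p53': 0.7}, 'not_a_smiles': {'NR-AR': 0.5, 'SR-p53': 0.5}}
```

Because the outer keys are SMILES strings, duplicate inputs collapse into a single entry in the returned dictionary.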

## Configuration Overview (`config/config.json`)

```json
{
  "seed": 42,
  "dataset": {"name": "ml-jku/tox21"},
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "map4_dim": 1024,
    "cache_dir": "./checkpoints/cache"
  },
  "training": {
    "optuna_trials": 40,
    "n_estimators": [50, 500, 1000],
    "early_stopping_rounds": 100,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {"enabled": true},
  "output": {"checkpoint_dir": "./checkpoints"}
}
```

- Switch `features.type` to `"map4"` to use MAP4 fingerprints (the MAP4 dependency is installed by default).
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
- Increase `training.optuna_trials` for a more exhaustive search if compute allows.
- Set `training.n_estimators` to either a single integer or a list of candidate values (default `[50, 500, 1000]`) to control the Optuna search space for the `n_estimators` hyperparameter.
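Accepting either an integer or a list for `training.n_estimators` means the value has to be normalized before it reaches the search. A minimal sketch, assuming the config is read with the standard `json` module; the helper `n_estimators_candidates` is hypothetical and not the parsing code in `train.py`.

```python
import json
from typing import Any, Dict, List

DEFAULT_N_ESTIMATORS = [50, 500, 1000]  # default listed above


def n_estimators_candidates(config: Dict[str, Any]) -> List[int]:
    """Normalize training.n_estimators (an int or a list of ints) to a list of candidates."""
    value = config.get("training", {}).get("n_estimators", DEFAULT_N_ESTIMATORS)
    if isinstance(value, bool):  # bool is a subclass of int; reject it explicitly
        raise TypeError("training.n_estimators must be an int or a list of ints")
    if isinstance(value, int):
        return [value]
    candidates = [int(v) for v in value]  # always a fresh list, never the shared default
    if not candidates:
        raise ValueError("training.n_estimators must not be an empty list")
    return candidates


# In practice: with open("config/config.json") as fh: config = json.load(fh)
config = json.loads('{"training": {"n_estimators": 500}, "multitask": {"enabled": false}}')
print(n_estimators_candidates(config))  # [500]
print(config["multitask"]["enabled"])  # False
```

The resulting list can be passed to Optuna's `trial.suggest_categorical("n_estimators", candidates)`; a single-element list effectively fixes the value.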

## Repository Layout

- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
- `src/features.py` – fingerprint computation with disk caching.
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
- `src/stage_two.py` – multitask feature augmentation and model training.
- `src/constants.py`, `src/seed.py` – shared utilities.
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
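Disk caching of fingerprints needs a cache key that changes whenever the inputs change. The sketch below shows one hypothetical way to derive such a key from the `features` block of `config/config.json` and the SMILES list; `fingerprint_cache_path` and the `.npy` naming scheme are assumptions, not the scheme used by `src/features.py`.

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Dict, Sequence


def fingerprint_cache_path(cache_dir: str, feature_cfg: Dict[str, Any], smiles: Sequence[str]) -> Path:
    """Derive a deterministic cache file path (hypothetical helper).

    Changing any fingerprint setting or the molecule list yields a different file name.
    """
    # cache_dir only says where to store files, so it is excluded from the key.
    settings = {k: v for k, v in feature_cfg.items() if k != "cache_dir"}
    payload = json.dumps({"settings": settings, "smiles": list(smiles)}, sort_keys=True)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return Path(cache_dir) / f"{feature_cfg['type']}_{digest}.npy"


cfg = {"type": "ecfp", "radius": 2, "n_bits": 1024, "map4_dim": 1024, "cache_dir": "./checkpoints/cache"}
p1 = fingerprint_cache_path(cfg["cache_dir"], cfg, ["CCO", "c1ccccc1"])
p2 = fingerprint_cache_path(cfg["cache_dir"], {**cfg, "n_bits": 2048}, ["CCO", "c1ccccc1"])
print(p1 != p2)  # True
```

Hashing the sorted JSON payload keeps the key stable across runs, since `sort_keys=True` removes any dependence on dictionary insertion order.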