# LightGBM-Based Multitask Workflow for Tox21

This document proposes a stepwise plan to replace the current GIN baseline (`train.py`, `predict.py`, `src/`) with a gradient-boosting pipeline that remains compatible with the leaderboard I/O contract. Each phase can be validated independently before moving to the next, ensuring we have working training and inference artifacts at all times.

---

## 0. Repository Integration Checklist

- **Entry points stay the same.** `train.py` must continue to train from `config/config.json` and drop an inference-ready artifact into `checkpoints/`. `predict.py` must keep the `predict(smiles_list)` signature and return the nested `{smiles: {target: score}}` mapping.
- **New modules.** Introduce `src/features.py` (fingerprints & caching), `src/lightgbm_trainer.py` (shared utilities for training/evaluation), and `src/stage_two.py` (cross-task augmentation logic). Keep `src/preprocess.py` for SMILES standardization + RDKit `Mol` construction so inference stays aligned with training.
- **Dependencies.** Add `lightgbm`, `optuna`, `rdkit-pypi`, and optionally the `map4` package (or its reference code) to `requirements.txt`. Verify any native dependencies are supported by the Spaces environment.
- **Artifacts.** Store per-task boosters as `checkpoints/stage1_{task}.txt` and `checkpoints/stage2_{task}.txt` (LightGBM text dumps; see the sketch after this list). Derived predictions (e.g., stage-1 OOF matrices) should live under `checkpoints/cache/` or `/tmp` during training, but inference must rely only on checkpoint files generated by `train.py`.
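For concreteness, a minimal sketch of the checkpoint convention, using the standard LightGBM save/load APIs; the helper names are placeholders, only the path scheme comes from the plan above:

```python
import lightgbm as lgb

def save_booster(booster: lgb.Booster, stage: int, task: str) -> str:
    """Persist a trained booster as a LightGBM text dump (human-readable)."""
    path = f"checkpoints/stage{stage}_{task}.txt"
    booster.save_model(path)
    return path

def load_booster(stage: int, task: str) -> lgb.Booster:
    """Inference side: rebuild the booster from the text dump alone."""
    return lgb.Booster(model_file=f"checkpoints/stage{stage}_{task}.txt")
```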
---

## 1. Phase 1 — Baseline LightGBM with Optuna

### 1.1 Data handling

1. Load the Hugging Face dataset inside `train.py` exactly as today (`load_dataset("ml-jku/tox21", token=TOKEN)`).
2. Keep the same per-split segmentation (train/validation/test) to remain comparable with the GIN baseline.
3. Convert SMILES strings to RDKit `Mol` objects using the existing cleaners in `src/preprocess.py`. For the baseline, we can featurize molecules with a minimal descriptor set (e.g., RDKit physicochemical descriptors) while fingerprints are being implemented. A loading sketch follows this list.
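A minimal sketch of steps 1 and 3, assuming a `smiles` column per split (as the Section 4.1 pseudo-flow does) and an environment-variable token source; both are assumptions, not confirmed details of the current script:

```python
import os

from datasets import load_dataset
from rdkit import Chem

# TOKEN sourcing mirrors whatever train.py already does; the env-var name is an assumption.
TOKEN = os.environ.get("HF_TOKEN")
ds = load_dataset("ml-jku/tox21", token=TOKEN)
train_smiles = ds["train"]["smiles"]

# RDKit returns None for unparseable SMILES; track indices so labels stay aligned.
mols = [Chem.MolFromSmiles(smi) for smi in train_smiles]
valid_idx = [i for i, m in enumerate(mols) if m is not None]
```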
### 1.2 Baseline features

Use easily computed descriptors such as:

- Molecular weight, logP, TPSA, number of H-bond donors/acceptors, rotatable bonds, aromatic proportion, etc.
- Concatenated one-hot encodings for atom-count bins (C, N, O, halogens).

This gives a quick tabular vector per SMILES while fingerprint work is in progress.
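A sketch of the baseline featurizer using plain RDKit descriptor calls; the exact descriptor set is illustrative, and the atom-count binning is omitted:

```python
import numpy as np
from rdkit.Chem import Crippen, Descriptors, rdMolDescriptors

def baseline_descriptors(mol) -> np.ndarray:
    """Small physicochemical descriptor vector for one RDKit Mol."""
    n_atoms = mol.GetNumAtoms()
    aromatic_frac = sum(a.GetIsAromatic() for a in mol.GetAtoms()) / max(n_atoms, 1)
    return np.array([
        Descriptors.MolWt(mol),                       # molecular weight
        Crippen.MolLogP(mol),                         # logP
        Descriptors.TPSA(mol),                        # topological polar surface area
        rdMolDescriptors.CalcNumHBD(mol),             # H-bond donors
        rdMolDescriptors.CalcNumHBA(mol),             # H-bond acceptors
        rdMolDescriptors.CalcNumRotatableBonds(mol),  # rotatable bonds
        aromatic_frac,                                # aromatic proportion
    ], dtype=np.float32)
```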
### 1.3 Training objective

- **Task granularity:** Train one LightGBM binary classifier per Tox21 task (12 total). Targets remain the provided binary toxicity labels.
- **Metric:** ROC-AUC per task, with the macro-average for reporting (mirrors the leaderboard metric).
- **Data split:** For each task, drop rows with missing labels and perform K-fold CV (e.g., 5 folds) inside Optuna to make the best use of labeled data (see the masking sketch after this list).
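Tox21 labels are sparse per task, so the per-task view has to mask missing entries before CV. A minimal sketch, assuming labels arrive as an `(N, T)` float array with `NaN` for missing annotations (the actual storage format is up to `src/lightgbm_trainer.py`):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def task_view(features: np.ndarray, labels: np.ndarray, task_idx: int):
    """Keep only the rows that carry a label for this task."""
    y = labels[:, task_idx]
    mask = ~np.isnan(y)
    return features[mask], y[mask].astype(int)

def cv_splits(X, y, n_folds=5, seed=42):
    """Stratified folds so each fold preserves the (rare) positive rate."""
    cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    yield from cv.split(X, y)
```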
### 1.4 Optuna search space

Within `src/lightgbm_trainer.py`, expose an `objective(trial, task_name)` that:

1. Samples:
   - `learning_rate ∈ [1e-3, 0.2]` (log scale)
   - `num_leaves ∈ [16, 256]`
   - `max_depth ∈ [-1, 12]`
   - `min_data_in_leaf ∈ [10, 200]`
   - `feature_fraction ∈ [0.5, 1.0]`
   - `bagging_fraction ∈ [0.5, 1.0]` with `bagging_freq ∈ [1, 10]`
   - `lambda_l1`, `lambda_l2` (10^-8 to 10^1, log scale)
2. Trains the LightGBM model on each CV split and averages ROC-AUC.
3. Returns the negative mean ROC-AUC so Optuna can minimize the objective (equivalently, return the mean and create the study with `direction="maximize"`).

Persist the best hyperparameters per task into the config (or a JSON artifact) so `predict.py` can instantiate the booster with exact values. When data volume is small, Optuna's `Study` can share the same random seed for reproducibility (`src/seed.py` can be reused).
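A sketch of that objective, with the task's arrays passed in directly rather than looked up by `task_name` so the snippet stays self-contained; the early-stopping callback and `num_boost_round` are assumptions, not part of the plan:

```python
import lightgbm as lgb
import numpy as np
import optuna
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

def objective(trial: optuna.Trial, X: np.ndarray, y: np.ndarray,
              n_folds: int = 5, seed: int = 42) -> float:
    # Search space from the list above.
    params = {
        "objective": "binary",
        "metric": "auc",
        "verbosity": -1,
        "seed": seed,
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 16, 256),
        "max_depth": trial.suggest_int("max_depth", -1, 12),
        "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 10, 200),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
    }
    aucs = []
    cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for tr, va in cv.split(X, y):
        booster = lgb.train(
            params,
            lgb.Dataset(X[tr], label=y[tr]),
            num_boost_round=2000,
            valid_sets=[lgb.Dataset(X[va], label=y[va])],
            callbacks=[lgb.early_stopping(50, verbose=False)],
        )
        aucs.append(roc_auc_score(y[va], booster.predict(X[va])))
    return -float(np.mean(aucs))  # negated so a minimizing study maximizes AUC

# Usage: study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=42))
#        study.optimize(lambda t: objective(t, X, y), n_trials=50)
```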
### 1.5 Deliverables for Phase 1

- Updated `train.py` calling into `src/lightgbm_trainer.train_single_task(task_name, features, labels, config)`.
- `checkpoints/stage1_{task}.txt` boosters (even though they are labeled "stage 1", they form the baseline deliverable).
- Validation report (per-task ROC-AUC) saved to `checkpoints/metrics_stage1.json`.
- `predict.py` loads each per-task LightGBM model, computes baseline descriptors on the fly, and returns predictions.

---
## 2. Phase 2 — Fingerprint-Based Representations

### 2.1 Feature computation

Implement `src/features.py` with methods:

- `compute_ecfp(mol, radius=2, n_bits=1024)` using RDKit's `GetMorganFingerprintAsBitVect`.
- `compute_map4(mol)` via the MAP4 codebase (counts of hashed patterns). Because MAP4 is computationally heavier, cache features to disk (e.g., `cache/fingerprints_{split}.npz`).
- `fingerprint_pipeline(smiles_list, fingerprint_type)` that accepts sanitized SMILES, constructs `Mol` objects, and returns a dense `np.ndarray` (see the sketch after this list).
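A sketch of the ECFP path, using the canonical RDKit bit-vector-to-NumPy conversion; the placeholder row for unparseable SMILES and the inline `MolFromSmiles` (standing in for `src/preprocess.py`) are assumptions:

```python
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def compute_ecfp(mol, radius: int = 2, n_bits: int = 1024) -> np.ndarray:
    """Morgan/ECFP bit vector as a NumPy row."""
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((0,), dtype=np.int8)  # ConvertToNumpyArray resizes in place
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr

def fingerprint_pipeline(smiles_list, fingerprint_type: str = "ecfp",
                         n_bits: int = 1024) -> np.ndarray:
    rows = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)  # preprocess.py sanitization would go here
        if mol is None:
            rows.append(np.zeros(n_bits, dtype=np.int8))  # placeholder row
        else:
            rows.append(compute_ecfp(mol, n_bits=n_bits))
    return np.vstack(rows)
```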
### 2.2 Integration

- Update `train.py` to choose the fingerprint type from config (e.g., `config["features"]["type"] = "ecfp"`).
- Align `predict.py` to call the same fingerprint builder on incoming SMILES.
- Maintain metadata describing fingerprint dimensionality and type in a manifest (e.g., `checkpoints/features.json`) so inference knows how to parse the stored LightGBM feature order.

### 2.3 Training flow

Apart from the enriched features, Phase 2 reuses the Phase 1 training loop. If resource constraints exist, we can:

- Run Optuna once on a representative task (e.g., NR-AhR) and reuse its best hyperparameters for all tasks; or
- Run Optuna briefly per task (e.g., 30 trials) and share the results.

### 2.4 Deliverables

- Fingerprint cache builders + unit tests (small set of SMILES).
- Configurable training/inference that toggles between baseline descriptors and fingerprint vectors.
- Updated metrics comparing descriptors vs. ECFP vs. MAP4.

---
## 3. Phase 3 — Cross-Task Label Augmentation

### 3.1 Motivation

By incorporating predictions from other tasks, we expose LightGBM to shared toxicity patterns without building a fully joint model. This is especially valuable for underrepresented tasks where correlated labels provide additional signal.

### 3.2 Feature construction

Given `T = 12` tasks and fingerprint dimension `D`, the augmented features for task `k` are:

```
X_k = [fingerprint_vector (D dims), ŷ_1, …, ŷ_{k-1}, ŷ_{k+1}, …, ŷ_T]
```

where `ŷ_t` is the stage-1 prediction for task `t` on the same molecule. Use floats instead of hard labels to preserve uncertainty.
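A minimal sketch of that assembly with NumPy; the function name is a placeholder for whatever `src/stage_two.py` ends up exposing:

```python
import numpy as np

def augment_features(fingerprints: np.ndarray, stage1_preds: np.ndarray,
                     task_idx: int) -> np.ndarray:
    """Build X_k: fingerprints plus stage-1 scores for every task except k.

    fingerprints: (N, D); stage1_preds: (N, T) float probabilities.
    Returns an (N, D + T - 1) matrix.
    """
    others = np.delete(stage1_preds, task_idx, axis=1)  # drop column k (ŷ_k)
    return np.hstack([fingerprints, others])
```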
### 3.3 Implementation details

1. **Collect stage-1 predictions.**
   - After Phase 2 training, run inference with each stage-1 model on every molecule in the train/val/test splits.
   - Store the `N × T` prediction matrix in `checkpoints/stage1_predictions_{split}.npz`.
2. **Align missing data.**
   - If task `t` lacks a label for a molecule, mask it during stage-1 training but still compute predictions for other tasks so the feature matrix stays dense.
3. **Prevent data leakage.**
   - During training, use out-of-fold (OOF) predictions for the stage-1 features so models do not see their own ground-truth labels through the augmented vector.
   - Implementation: for each fold, train stage-1 LightGBM on K−1 folds, predict on the held-out fold, and concatenate the predictions (see the sketch after this list).
4. **Config surface.**
   - `config["multitask"]["use_stage1_predictions"] = true/false`
   - `config["multitask"]["prediction_source"] = "oof" | "full_train"` to switch between strict OOF features and simpler (but leakier) full-train predictions for debugging.
### 3.4 Training

Once augmented features are ready, rerun the single-task LightGBM training per target (`stage2`). Hyperparameter search can be narrower because fingerprints already provide a strong baseline; focus on `num_leaves`, `feature_fraction`, and regularization strength.

### 3.5 Deliverables

- Scripts that generate OOF prediction matrices.
- Updated `train.py` orchestration:
  1. Train Stage 1 models.
  2. Materialize the cross-task prediction cache.
  3. Train Stage 2 models from augmented features.
- Metrics comparing Stage 1 vs. Stage 2 per task.

---
## 4. Phase 4 — Two-Stage Training & Inference

### 4.1 Training orchestration

Pseudo-flow for `train.py`:

```python
def train(config):
    ds = load_dataset(...)
    mols = preprocess.standardize(ds["train"]["smiles"])
    fp_cache = features.fingerprint_pipeline(mols, config["features"])

    stage1 = StageOneTrainer(config)
    stage1.train_all_tasks(fp_cache, labels, splits)
    stage1.save_models("checkpoints/stage1_*.txt")

    pred_cache = stage1.generate_predictions(fp_cache, splits, use_oof=True)

    stage2 = StageTwoTrainer(config)
    stage2.train_all_tasks(fp_cache, pred_cache, labels)
    stage2.save_models("checkpoints/stage2_*.txt")

    dump_metrics(stage1.metrics, stage2.metrics)
```
### 4.2 Inference pipeline (`predict.py`)

1. **Fingerprint computation:** identical to training (deterministic sanitization).
2. **Stage-1 pass:** load every `stage1_{task}.txt`, predict on the incoming SMILES batch, and collect the predictions.
3. **Stage-2 pass:** for each task `k`, build `[fingerprint, predicted_labels_except_k]` on the fly and evaluate the corresponding stage-2 booster.
4. **Output:** return the stage-2 predictions for leaderboard submission. Optionally include stage-1 scores in the response for debugging, but the official output should stick to stage-2 values. A sketch of the full pass follows this list.
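A sketch of that pass; the task list (the 12 Tox21 target names, which must match the dataset's columns) and the `fingerprint_pipeline` import are assumptions carried over from the sections above:

```python
import lightgbm as lgb
import numpy as np

from src.features import fingerprint_pipeline  # proposed in Phase 2

TASKS = [  # illustrative; align with the actual dataset column names
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]

def predict(smiles_list):
    """Leaderboard contract: {smiles: {task: score}} built from stage-2 boosters."""
    fps = fingerprint_pipeline(smiles_list)   # (N, D), same builder as training
    stage1 = np.column_stack([                # (N, T) stage-1 pass
        lgb.Booster(model_file=f"checkpoints/stage1_{t}.txt").predict(fps)
        for t in TASKS
    ])
    out = {smi: {} for smi in smiles_list}
    for k, task in enumerate(TASKS):
        X_k = np.hstack([fps, np.delete(stage1, k, axis=1)])  # drop own column
        booster = lgb.Booster(model_file=f"checkpoints/stage2_{task}.txt")
        for smi, score in zip(smiles_list, booster.predict(X_k)):
            out[smi][task] = float(score)
    return out
```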
### 4.3 Failure modes & mitigations

- **Unrecognized SMILES:** fall back to zeros or 0.5 predictions like the current baseline, but log warnings so we can monitor failure rates.
- **Missing checkpoint:** raise an informative exception instructing users to rerun `train.py`.
- **Performance drift:** store SHA or timestamp metadata with checkpoints to trace which training configuration produced a given model.

---

## 5. Configuration & Experiment Tracking
Proposed structure for `config/config.json`:

```json
{
  "seed": 42,
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "use_counts": false
  },
  "training": {
    "n_folds": 5,
    "n_optuna_trials": 50,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {
    "enabled": true,
    "use_stage1_predictions": true,
    "prediction_source": "oof"
  }
}
```

Track experiment results in `checkpoints/experiments.csv` with the columns `[timestamp, fingerprint, stage, task, auc, params_hash]`.
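A standard-library sketch of appending to that file; the `params_hash` scheme (a short SHA-1 of the sorted JSON params) is an assumption:

```python
import csv
import hashlib
import json
import os
import time

COLUMNS = ["timestamp", "fingerprint", "stage", "task", "auc", "params_hash"]

def log_experiment(fingerprint: str, stage: int, task: str, auc: float,
                   params: dict, path: str = "checkpoints/experiments.csv"):
    params_hash = hashlib.sha1(
        json.dumps(params, sort_keys=True).encode()
    ).hexdigest()[:10]
    write_header = not os.path.exists(path)
    with open(path, "a", newline="") as fh:
        writer = csv.writer(fh)
        if write_header:
            writer.writerow(COLUMNS)  # create the header once
        writer.writerow([time.strftime("%Y-%m-%dT%H:%M:%S"),
                         fingerprint, stage, task, f"{auc:.4f}", params_hash])
```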
---

## 6. Testing & Validation

- **Unit tests:** ensure fingerprint builders reproduce known vectors (compare with the RDKit reference) and that cross-task feature assembly drops the correct task column (see the sketch after this list).
- **Integration tests:** small toy dataset (3 tasks, <50 samples) to run the full Stage1→Stage2 pipeline quickly. Assert the shapes of caches and that inference matches training predictions.
- **Performance tracking:** plot per-task ROC-AUC improvements by phase to confirm each enhancement adds value.
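A sketch of the unit tests, assuming the module layout proposed in Section 0 (the imports are hypothetical until those files exist):

```python
import numpy as np

from src.features import fingerprint_pipeline  # hypothetical until Phase 2 lands
from src.stage_two import augment_features     # hypothetical until Phase 3 lands

def test_fingerprints_are_deterministic():
    smiles = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]
    a = fingerprint_pipeline(smiles)
    b = fingerprint_pipeline(smiles)
    assert a.shape == (3, 1024)
    assert np.array_equal(a, b)

def test_augmentation_drops_own_task_column():
    fps = np.zeros((4, 8))
    preds = np.arange(4 * 12, dtype=float).reshape(4, 12)  # all-distinct values
    X_2 = augment_features(fps, preds, task_idx=2)
    assert X_2.shape == (4, 8 + 11)
    # no value from task 2's column may survive in the appended block
    assert not np.isin(preds[:, 2], X_2[:, 8:]).any()
```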
---

## 7. Suggested Implementation Milestones

1. **M1:** Skeleton LightGBM trainer + Optuna integration (Phase 1). ✓
2. **M2:** Fingerprint computation module with caching + updated training/inference (Phase 2).
3. **M3:** Stage-1 prediction cache + feature augmentation (Phase 3).
4. **M4:** End-to-end Stage1→Stage2 orchestration, packaging of checkpoints, and inference updates (Phase 4).
5. **M5:** Documentation + automated tests to guard against regressions.

This phased roadmap keeps the leaderboard interface intact while progressively increasing the modeling capacity from simple descriptors to multitask-enhanced fingerprints.