mschuh committed on
Commit 0dd7279 · verified · 1 Parent(s): 759324e

Delete docs

Files changed (1)
  1. docs/proposed_lightgbm_framework.md +0 -203
docs/proposed_lightgbm_framework.md DELETED
# LightGBM-Based Multitask Workflow for Tox21

This document proposes a stepwise plan to replace the current GIN baseline (`train.py`, `predict.py`, `src/`) with a Gradient Boosting pipeline that remains compatible with the leaderboard I/O contract. Each phase can be validated independently before moving to the next, ensuring we have working training and inference artifacts at all times.

---

## 0. Repository Integration Checklist
- **Entry-points stay the same.** `train.py` must continue to train from `config/config.json` and drop an inference-ready artifact into `checkpoints/`. `predict.py` must keep the `predict(smiles_list)` signature and return the nested `{smiles: {target: score}}` mapping.
- **New modules.** Introduce `src/features.py` (fingerprints & caching), `src/lightgbm_trainer.py` (shared utilities for training/evaluation), and `src/stage_two.py` (cross-task augmentation logic). Keep `src/preprocess.py` for SMILES standardization + RDKit `Mol` construction so inference stays aligned with training.
- **Dependencies.** Add `lightgbm`, `optuna`, and `rdkit` (the legacy `rdkit-pypi` wheel is deprecated) to `requirements.txt`, plus optionally the `map4` package or vendored MAP4 reference code. Verify any native dependencies are supported by the Spaces environment.
- **Artifacts.** Store per-task boosters as `checkpoints/stage1_{task}.txt` and `checkpoints/stage2_{task}.txt` (LightGBM text dumps). Derived predictions (e.g., stage-1 OOF matrices) should live under `checkpoints/cache/` or `/tmp` during training, but inference must rely only on checkpoint files generated by `train.py`.

---

## 1. Phase 1 — Baseline LightGBM with Optuna

### 1.1 Data handling
1. Load the Hugging Face dataset inside `train.py` exactly as today (`load_dataset("ml-jku/tox21", token=TOKEN)`).
2. Keep the same per-split segmentation (train/validation/test) to remain comparable with the GIN baseline.
3. Convert SMILES strings to RDKit `Mol` objects using the existing cleaners in `src/preprocess.py`. For the baseline, featurize molecules with a minimal descriptor set (e.g., RDKit physicochemical descriptors) while fingerprints are being implemented.

### 1.2 Baseline features
Use easily computed descriptors such as:
- Molecular weight, logP, TPSA, number of H-bond donors/acceptors, rotatable bonds, aromatic proportion, etc.
- One-hot encodings for atom-count bins (C, N, O, halogens), concatenated to the descriptor vector.

This gives a quick tabular vector per SMILES while fingerprint work is in progress.
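
A minimal featurizer along these lines might look as follows (the descriptor choice and the `baseline_descriptors` name are illustrative, not repo code):

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors

def baseline_descriptors(smiles):
    """Small physicochemical descriptor vector for one SMILES string.

    Returns None when RDKit cannot parse the input.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    heavy = mol.GetNumHeavyAtoms() or 1
    aromatic = sum(atom.GetIsAromatic() for atom in mol.GetAtoms())
    return np.array([
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.NumRotatableBonds(mol),
        aromatic / heavy,  # aromatic proportion of heavy atoms
    ], dtype=np.float32)
```

The atom-count one-hot bins would be appended to this vector in the same function once the bin edges are chosen.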

### 1.3 Training objective
- **Task granularity:** Train one LightGBM binary classifier per Tox21 task (12 total). Targets remain the provided binary toxicity labels.
- **Metric:** ROC-AUC per task, with the macro-average for reporting (mirrors the leaderboard metric).
- **Data split:** For each task, drop rows with missing labels and perform K-fold CV (e.g., 5 folds) inside Optuna to make the best use of labeled data.

### 1.4 Optuna search space
Within `src/lightgbm_trainer.py`, expose an `objective(trial, task_name)` that:
1. Samples:
   - `learning_rate ∈ [1e-3, 0.2]` (log scale)
   - `num_leaves ∈ [16, 256]`
   - `max_depth ∈ [-1, 12]`
   - `min_data_in_leaf ∈ [10, 200]`
   - `feature_fraction ∈ [0.5, 1.0]`
   - `bagging_fraction ∈ [0.5, 1.0]` with `bagging_freq ∈ [1, 10]`
   - `lambda_l1`, `lambda_l2` (10^-8 to 10^1, log scale)
2. Trains the LightGBM model on each CV split and averages ROC-AUC.
3. Returns the negative mean ROC-AUC (an Optuna study minimizes by default; equivalently, return the mean ROC-AUC and create the study with `direction="maximize"`).

Persist the best hyperparameters per task into the config (or a JSON artifact) so `predict.py` can instantiate the booster with exact values. Fix the sampler's random seed for reproducibility (`src/seed.py` can be reused).

### 1.5 Deliverables for Phase 1
- Updated `train.py` calling into `src/lightgbm_trainer.train_single_task(task_name, features, labels, config)`.
- `checkpoints/stage1_{task}.txt` boosters (even though they are “stage 1”, they form the baseline deliverable).
- Validation report (per-task ROC-AUC) saved to `checkpoints/metrics_stage1.json`.
- `predict.py` loads each per-task LightGBM model, computes baseline descriptors on the fly, and returns predictions.

---

## 2. Phase 2 — Fingerprint-Based Representations

### 2.1 Feature computation
Implement `src/features.py` with methods:
- `compute_ecfp(mol, radius=2, n_bits=1024)` using RDKit's `GetMorganFingerprintAsBitVect`.
- `compute_map4(mol)` via the MAP4 reference implementation (MinHashed Atom-Pair fingerprint). Because MAP4 is computationally heavier, cache features to disk (e.g., `cache/fingerprints_{split}.npz`).
- `fingerprint_pipeline(smiles_list, fingerprint_type)` that accepts sanitized SMILES, constructs `Mol` objects, and returns a dense `np.ndarray`.
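
A sketch of the ECFP path (MAP4 omitted; the zero-vector fallback for unparsable SMILES is an assumption, not repo behavior):

```python
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def compute_ecfp(mol, radius=2, n_bits=1024):
    """Morgan/ECFP bit fingerprint as a dense float32 vector."""
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros(n_bits, dtype=np.float32)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr

def fingerprint_pipeline(smiles_list, fingerprint_type="ecfp", n_bits=1024):
    """Stack one fingerprint row per input SMILES into an (N, n_bits) array."""
    rows = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Placeholder row for unparsable SMILES; keeps row alignment.
            rows.append(np.zeros(n_bits, dtype=np.float32))
            continue
        if fingerprint_type == "ecfp":
            rows.append(compute_ecfp(mol, n_bits=n_bits))
        else:
            raise ValueError(f"unsupported fingerprint type: {fingerprint_type}")
    return np.vstack(rows)
```

The disk cache would wrap this with an `np.savez`/`np.load` layer keyed by split name.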

### 2.2 Integration
- Update `train.py` to choose the fingerprint type from config (e.g., `config["features"]["type"] = "ecfp"`).
- Align `predict.py` to call the same fingerprint builder on incoming SMILES.
- Maintain metadata describing fingerprint dimensionality and type in a manifest (e.g., `checkpoints/features.json`) so inference knows how to parse the stored LightGBM feature order.
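
For example, `checkpoints/features.json` might record (field names are a suggestion, mirroring the config keys below):

```json
{
  "type": "ecfp",
  "radius": 2,
  "n_bits": 1024,
  "feature_order": "fp_0 … fp_1023"
}
```

`predict.py` can then refuse to load a booster whose manifest disagrees with the configured featurizer.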

### 2.3 Training flow
Apart from the enriched features, Phase 2 reuses the Phase 1 training loop. If resource constraints exist, we can:
- Run Optuna once on a representative task (e.g., NR-AhR) and reuse its best hyperparameters for all tasks; or
- Run Optuna briefly per task (e.g., 30 trials) and share results.

### 2.4 Deliverables
- Fingerprint cache builders + unit tests (small set of SMILES).
- Configurable training/inference that toggles between baseline descriptors and fingerprint vectors.
- Updated metrics comparing descriptors vs. ECFP vs. MAP4.

---

## 3. Phase 3 — Cross-Task Label Augmentation

### 3.1 Motivation
By incorporating predictions from other tasks, we expose LightGBM to shared toxicity patterns without building a fully joint model. This is especially valuable for underrepresented tasks where correlated labels provide additional signal.

### 3.2 Feature construction
Given `T = 12` tasks and fingerprint dimension `D`, the augmented features for task `k` are:

```
X_k = [fingerprint_vector (D dims), ŷ_1, …, ŷ_{k-1}, ŷ_{k+1}, …, ŷ_T]
```

where `ŷ_t` are the stage-1 predictions for task `t` on the same molecule. Use floats instead of hard labels to preserve uncertainty.
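
Assembling `X_k` reduces to a single column drop plus concatenation (function and argument names assumed):

```python
import numpy as np

def augment_features(fingerprints, stage1_preds, task_idx):
    """Build X_k: fingerprints + stage-1 scores for every task except task_idx.

    fingerprints: (N, D) array; stage1_preds: (N, T) array of float scores.
    Returns an (N, D + T - 1) array.
    """
    other = np.delete(stage1_preds, task_idx, axis=1)  # drop column k
    return np.hstack([fingerprints, other])
```

Keeping this as one shared helper guarantees training and inference drop the same column.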

### 3.3 Implementation details
1. **Collect stage-1 predictions.**
   - After Phase 2 training, run inference with each stage-1 model on every molecule in the train/val/test splits.
   - Store the `N × T` prediction matrix in `checkpoints/stage1_predictions_{split}.npz`.
2. **Align missing data.**
   - If task `t` lacks a label for a molecule, mask it during stage-1 training but still compute predictions for other tasks so the feature matrix stays dense.
3. **Data leakage prevention.**
   - During training, use out-of-fold (OOF) predictions for the stage-1 features so models do not see their own ground-truth labels through the augmented vector.
   - Implementation: for each fold, train stage-1 LightGBM on K−1 folds, predict on the held-out fold, and concatenate predictions.
4. **Config surface.**
   - `config["multitask"]["use_stage1_predictions"] = true/false`
   - `config["multitask"]["prediction_source"] = "oof" | "full_train"` to switch between strict OOF features and simpler (but leakier) full-train predictions for debugging.
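
The OOF loop in step 3 can be sketched generically; `fit_predict` is a hypothetical hook that trains a stage-1 model on the training folds and scores the held-out fold:

```python
import numpy as np
from sklearn.model_selection import KFold

def oof_predictions(X, y, fit_predict, n_folds=5, seed=42):
    """Out-of-fold stage-1 scores: each row is predicted by a model that
    never saw that row during training.

    fit_predict(X_train, y_train, X_valid) -> 1-D scores for X_valid.
    """
    oof = np.zeros(len(X))
    cv = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for tr, va in cv.split(X):
        oof[va] = fit_predict(X[tr], y[tr], X[va])
    return oof
```

In the real pipeline `fit_predict` would wrap `lgb.train` with the task's tuned parameters; the resulting columns are stacked into the `N × T` matrix from step 1.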

### 3.4 Training
Once augmented features are ready, rerun the single-task LightGBM training per target (`stage2`). The hyperparameter search can be narrower because fingerprints already provide a strong baseline; focus on `num_leaves`, `feature_fraction`, and regularization strength.

### 3.5 Deliverables
- Scripts that generate OOF prediction matrices.
- Updated `train.py` orchestration:
  1. Train Stage 1 models.
  2. Materialize the cross-task prediction cache.
  3. Train Stage 2 models from the augmented features.
- Metrics comparing Stage 1 vs. Stage 2 per task.

---

## 4. Phase 4 — Two-Stage Training & Inference

### 4.1 Training orchestration
Pseudo-flow for `train.py`:

```python
def train(config):
    ds = load_dataset(...)
    mols = preprocess.standardize(ds["train"]["smiles"])
    fp_cache = features.fingerprint_pipeline(mols, config["features"])

    stage1 = StageOneTrainer(config)
    stage1.train_all_tasks(fp_cache, labels, splits)
    stage1.save_models("checkpoints/stage1_*.txt")

    pred_cache = stage1.generate_predictions(fp_cache, splits, use_oof=True)

    stage2 = StageTwoTrainer(config)
    stage2.train_all_tasks(fp_cache, pred_cache, labels)
    stage2.save_models("checkpoints/stage2_*.txt")

    dump_metrics(stage1.metrics, stage2.metrics)
```

### 4.2 Inference pipeline (`predict.py`)
1. **Fingerprint computation:** identical to training (deterministic sanitization).
2. **Stage-1 pass:** Load every `stage1_{task}.txt`, predict on the incoming SMILES batch, and collect predictions.
3. **Stage-2 pass:** For each task `k`, build `[fingerprint, predicted_labels_except_k]` on the fly and evaluate the corresponding stage-2 booster.
4. **Output:** Return the stage-2 predictions for leaderboard submission. Optionally include stage-1 scores in the response if needed for debugging (but the official output should stick to stage-2 values).
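
The stage-1/stage-2 passes reduce to the sketch below. Names are assumptions; each model object only needs a `.predict(ndarray) -> ndarray` method, e.g. a `lightgbm.Booster` loaded from the `checkpoints/stage{1,2}_{task}.txt` dumps:

```python
import numpy as np

def predict_two_stage(fingerprints, stage1_models, stage2_models):
    """Two-stage inference.

    fingerprints: (N, D) array for the incoming SMILES batch.
    stage1_models / stage2_models: {task_name: booster-like}, same keys;
    dict insertion order fixes the task (column) order for both stages.
    Returns {task_name: (N,) score array}.
    """
    tasks = list(stage1_models)
    # Stage 1: one score column per task.
    stage1 = np.column_stack([stage1_models[t].predict(fingerprints) for t in tasks])
    out = {}
    for k, task in enumerate(tasks):
        # Stage 2: fingerprint + all stage-1 scores except this task's own.
        other = np.delete(stage1, k, axis=1)
        out[task] = stage2_models[task].predict(np.hstack([fingerprints, other]))
    return out
```

`predict.py` would then map these arrays back into the `{smiles: {target: score}}` response.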

### 4.3 Failure modes & mitigations
- **Unrecognized SMILES:** fall back to zeros or 0.5 predictions like the current baseline, but log warnings so we can monitor failure rates.
- **Missing checkpoint:** raise an informative exception instructing users to rerun `train.py`.
- **Performance drift:** store SHA or timestamp metadata with checkpoints to trace which training configuration produced a given model.

---

## 5. Configuration & Experiment Tracking
Proposed structure for `config/config.json`:

```json
{
  "seed": 42,
  "features": {
    "type": "ecfp",
    "radius": 2,
    "n_bits": 1024,
    "use_counts": false
  },
  "training": {
    "n_folds": 5,
    "n_optuna_trials": 50,
    "lightgbm_params": {
      "objective": "binary",
      "metric": "auc",
      "verbosity": -1
    }
  },
  "multitask": {
    "enabled": true,
    "use_stage1_predictions": true,
    "prediction_source": "oof"
  }
}
```

Track experiment results in `checkpoints/experiments.csv` with columns `[timestamp, fingerprint, stage, task, auc, params_hash]`.

---

## 6. Testing & Validation
- **Unit tests:** Ensure fingerprint builders reproduce known vectors (compare with an RDKit reference) and that cross-task feature assembly drops the correct task column.
- **Integration tests:** A small toy dataset (3 tasks, <50 samples) to run the full Stage1→Stage2 pipeline quickly. Assert the shapes of caches and that inference matches training predictions.
- **Performance tracking:** Plot per-task ROC-AUC improvements by phase to confirm each enhancement adds value.

---

## 7. Suggested Implementation Milestones
1. **M1:** Skeleton LightGBM trainer + Optuna integration (Phase 1). ✓
2. **M2:** Fingerprint computation module with caching + updated training/inference (Phase 2).
3. **M3:** Stage-1 prediction cache + feature augmentation (Phase 3).
4. **M4:** End-to-end Stage1→Stage2 orchestration, packaging of checkpoints, and inference updates (Phase 4).
5. **M5:** Documentation + automated tests to guard against regressions.

This phased roadmap keeps the leaderboard interface intact while progressively increasing the modeling capacity from simple descriptors to multitask-enhanced fingerprints.