mschuh commited on
Commit
94b1553
·
verified ·
1 Parent(s): 8fae56b

Added first version

Browse files
.example.env ADDED
@@ -0,0 +1 @@
 
 
1
+ TOKEN=example_token
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ predict copy.py
2
+ hp_search/logs/*
3
+ hp_search/models/*
4
+ __pycache__
5
+ .env
6
+ notes.txt
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.11.4

# Run as a non-root user with uid 1000 (required by Hugging Face Spaces).
RUN useradd -m -u 1000 user
USER user
# Make user-level pip installs (~/.local/bin) available on PATH.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy and install requirements before the application code so the
# dependency layer is cached across rebuilds that only change source files.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
# Port 7860 is the port Hugging Face Spaces expects the app to listen on.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,12 +1,115 @@
1
- ---
2
- title: MultiTaskTox
3
- emoji: 🦀
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiTaskTox – LightGBM Fingerprint Classifier for Tox21
2
+
3
+ MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
4
+
5
+ ## Why MultiTaskTox?
6
+
7
+ - **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
8
+ - **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
9
+ - **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
10
+ - **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
11
+
12
+ ## Installation
13
+
14
+ ```bash
15
+ git clone https://huggingface.co/spaces/ml-jku/tox21_gin_classifier
16
+ cd tox21_gin_classifier
17
+ python -m venv .venv && source .venv/bin/activate
18
+ pip install --upgrade pip
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ The requirements include RDKit, LightGBM, Optuna, and the MAP4 fingerprint package so you can switch feature types via the config.
23
+
24
+ ## Training
25
+
26
+ 1. Create a `.env` file with your dataset token (on a Hugging Face Space, store it as a Space secret instead of committing a file):
27
+ ```
28
+ TOKEN=hf_xxx
29
+ ```
30
+ 2. Adjust `config/config.json` if needed (fingerprint type, Optuna trial count, etc.).
31
+ 3. Run:
32
+ ```bash
33
+ python train.py
34
+ ```
35
+
36
+ ### What `train.py` does
37
+
38
+ 1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
39
+ 2. Standardizes SMILES and builds fingerprints using `src/features.py`.
40
+ 3. For each target:
41
+ - Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
42
+ - Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
43
+ 4. Generates prediction matrices for both splits.
44
+ 5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
45
+ 6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
46
+
47
+ ## Inference
48
+
49
+ `predict.py` exposes:
50
+
51
+ ```python
52
+ from predict import predict
53
+
54
+ smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
55
+ results = predict(smiles)
56
+ ```
57
+
58
+ The function:
59
+ 1. Loads the training manifest to know which fingerprint type and checkpoints to use.
60
+ 2. Standardizes and fingerprints the SMILES on the fly.
61
+ 3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
62
+ 4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
63
+ 5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
64
+
65
+ ## Configuration Overview (`config/config.json`)
66
+
67
+ ```json
68
+ {
69
+ "seed": 42,
70
+ "dataset": {"name": "ml-jku/tox21"},
71
+ "features": {
72
+ "type": "ecfp",
73
+ "radius": 2,
74
+ "n_bits": 1024,
75
+ "map4_dim": 1024,
76
+ "cache_dir": "./checkpoints/cache"
77
+ },
78
+ "training": {
79
+ "optuna_trials": 40,
80
+ "boosting_rounds": 1500,
81
+ "early_stopping_rounds": 100,
82
+ "lightgbm_params": {
83
+ "objective": "binary",
84
+ "metric": "auc",
85
+ "verbosity": -1
86
+ }
87
+ },
88
+ "multitask": {"enabled": true},
89
+ "output": {"checkpoint_dir": "./checkpoints"}
90
+ }
91
+ ```
92
+
93
+ - Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
94
+ - Disable multitask behavior by setting `"multitask": {"enabled": false}`.
95
+ - Increase `optuna_trials` for a more exhaustive search if compute allows.
96
+
97
+ ## Repository Layout
98
+
99
+ - `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
100
+ - `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
101
+ - `src/preprocess.py` – dataset loading and SMILES standardization helpers.
102
+ - `src/features.py` – fingerprint computation with disk caching.
103
+ - `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
104
+ - `src/stage_two.py` – multitask feature augmentation and model training.
105
+ - `src/constants.py`, `src/seed.py` – shared utilities.
106
+ - `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
107
+ - `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
108
+
109
+ ## Tips
110
+
111
+ - Training relies on the `TOKEN` environment variable to access the Tox21 dataset on Hugging Face. Locally you can omit it if the dataset is public for your account.
112
+ - MAP4 fingerprints are more expensive to compute; enable the cache directory to avoid recomputation across runs.
113
+ - Use the saved metrics files to compare stage-one vs. stage-two AUCs and to trace which configuration produced a set of checkpoints.
114
+
115
+ Happy modeling! If you extend MultiTaskTox (new fingerprints, alternative learners, etc.), keep the `predict(smiles_list)` contract intact so your Space remains leaderboard compatible.
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
This is the main entry point for the FastAPI application.
The app handles the request to predict toxicity for a list of SMILES strings.
"""

# ---------------------------------------------------------------------------------------
# Dependencies and global variable definition
import os
from typing import List, Dict, Optional
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel, Field

# Leaderboard-contract inference function returning
# {smiles: {target_name: probability}} (see predict.py).
from predict import predict as predict_func

# NOTE(review): API_KEY is loaded here but no route in this file checks it,
# and Header/HTTPException/Optional are imported but unused — presumably
# groundwork for API-key authentication on /predict. Confirm whether auth
# should actually be enforced before exposing the Space publicly.
API_KEY = os.getenv("API_KEY")  # set via Space Secrets
16
+
17
+
18
+ # ---------------------------------------------------------------------------------------
19
class Request(BaseModel):
    """Request payload for /predict: between 1 and 1000 SMILES strings."""

    # NOTE(review): `min_items`/`max_items` are the Pydantic v1 spelling;
    # Pydantic v2 renames them to `min_length`/`max_length` (the old names
    # are deprecated). Confirm the pinned pydantic version in requirements.
    smiles: List[str] = Field(min_items=1, max_items=1000)
21
+
22
+
23
class Response(BaseModel):
    """Response payload returned by /predict."""

    # Nested mapping {smiles: {target_name: probability}}; kept as a plain
    # dict so the per-target keys are not constrained at the schema level.
    predictions: dict
    # Free-form model metadata (e.g. name, version). Pydantic copies field
    # defaults per instance, so the mutable `{}` default is safe here.
    model_info: Dict[str, str] = {}
26
+
27
+
28
# ASGI application served by uvicorn (see the Dockerfile CMD: "app:app").
app = FastAPI(title="toxicity-api")
29
+
30
+
31
@app.get("/")
def root():
    """Landing route: describe the available endpoints and basic usage."""
    endpoints = {
        "/metadata": "GET - API metadata and capabilities",
        "/healthz": "GET - Health check",
        "/predict": "POST - Predict toxicity for SMILES",
    }
    return {
        "message": "Toxicity Prediction API",
        "endpoints": endpoints,
        "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']}",
    }
42
+
43
+
44
@app.get("/metadata")
def metadata():
    """Return static model metadata and the 12 Tox21 toxicity endpoints.

    Fix: the name previously read "Tox21 GIN classifier", a leftover from
    the GIN baseline this Space replaced; the deployed model is the
    LightGBM-based MultiTaskTox (see README and /predict's model_info).
    """
    return {
        "name": "MultiTaskTox",
        # NOTE(review): /predict reports version "0.0.1" — align the two
        # version strings once a canonical version number is chosen.
        "version": "1.0.0",
        "tox_endpoints": [
            "NR-AR",
            "NR-AR-LBD",
            "NR-AhR",
            "NR-Aromatase",
            "NR-ER",
            "NR-ER-LBD",
            "NR-PPAR-gamma",
            "SR-ARE",
            "SR-ATAD5",
            "SR-HSE",
            "SR-MMP",
            "SR-p53",
        ],
    }
64
+
65
+
66
@app.get("/healthz")
def healthz():
    """Liveness probe; always returns {"ok": True} when the app is up."""
    return {"ok": True}
69
+
70
+
71
@app.post("/predict", response_model=Response)
def predict(request: Request):
    """Predict toxicity for the submitted SMILES strings.

    Delegates to the leaderboard `predict` function and wraps its nested
    {smiles: {target_name: probability}} mapping in the response schema.
    """
    scores = predict_func(request.smiles)
    return {
        "predictions": scores,
        "model_info": {"name": "MultiTaskTox", "version": "0.0.1"},
    }
config/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seed": 42,
3
+ "dataset": {
4
+ "name": "ml-jku/tox21"
5
+ },
6
+ "features": {
7
+ "type": "ecfp",
8
+ "radius": 2,
9
+ "n_bits": 1024,
10
+ "use_counts": false,
11
+ "map4_dim": 1024,
12
+ "cache_dir": "./checkpoints/cache"
13
+ },
14
+ "training": {
15
+ "optuna_trials": 40,
16
+ "boosting_rounds": 1500,
17
+ "early_stopping_rounds": 100,
18
+ "lightgbm_params": {
19
+ "objective": "binary",
20
+ "metric": "auc",
21
+ "verbosity": -1
22
+ }
23
+ },
24
+ "multitask": {
25
+ "enabled": true,
26
+ "prediction_source": "oof"
27
+ },
28
+ "output": {
29
+ "checkpoint_dir": "./checkpoints"
30
+ }
31
+ }
docs/proposed_lightgbm_framework.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LightGBM-Based Multitask Workflow for Tox21
2
+
3
+ This document proposes a stepwise plan to replace the current GIN baseline (`train.py`, `predict.py`, `src/`) with a Gradient Boosting pipeline that remains compatible with the leaderboard I/O contract. Each phase can be validated independently before moving to the next, ensuring we have working training and inference artifacts at all times.
4
+
5
+ ---
6
+
7
+ ## 0. Repository Integration Checklist
8
+ - **Entry-points stay the same.** `train.py` must continue to train from `config/config.json` and drop an inference-ready artifact into `checkpoints/`. `predict.py` must keep the `predict(smiles_list)` signature and return the nested `{smiles: {target: score}}` mapping.
9
+ - **New modules.** Introduce `src/features.py` (fingerprints & caching), `src/lightgbm_trainer.py` (shared utilities for training/evaluation), and `src/stage_two.py` (cross-task augmentation logic). Keep `src/preprocess.py` for SMILES standardization + RDKit `Mol` construction so inference stays aligned with training.
10
+ - **Dependencies.** Add `lightgbm`, `optuna`, `rdkit-pypi`, and optionally the `map4` package (or the MAP4 reference implementation) to `requirements.txt`. Verify any native dependencies are supported by the Spaces environment.
11
+ - **Artifacts.** Store per-task boosters as `checkpoints/stage1_{task}.txt` and `checkpoints/stage2_{task}.txt` (LightGBM text dumps). Derived predictions (e.g., stage-1 OOF matrices) should live under `checkpoints/cache/` or `/tmp` during training, but inference must rely only on checkpoint files generated by `train.py`.
12
+
13
+ ---
14
+
15
+ ## 1. Phase 1 — Baseline LightGBM with Optuna
16
+
17
+ ### 1.1 Data handling
18
+ 1. Load the Hugging Face dataset inside `train.py` exactly as today (`load_dataset("ml-jku/tox21", token=TOKEN)`).
19
+ 2. Keep the same per-split segmentation (train/validation/test) to remain comparable with the GIN baseline.
20
+ 3. Convert SMILES strings to RDKit `Mol` objects using the existing cleaners in `src/preprocess.py`. For the baseline, we can featurize molecules with a minimal descriptor set (e.g., RDKit physicochemical descriptors) while fingerprints are being implemented.
21
+
22
+ ### 1.2 Baseline features
23
+ Use easily-computed descriptors such as:
24
+ - Molecular weight, logP, TPSA, number of H-bond donors/acceptors, rotatable bonds, aromatic proportion, etc.
25
+ - Concatenate one-hot encodings for atom count bins (C, N, O, halogens).
26
+ This gives a quick tabular vector per SMILES while fingerprint work is in progress.
27
+
28
+ ### 1.3 Training objective
29
+ - **Task granularity:** Train one LightGBM binary classifier per Tox21 task (12 total). Targets remain the provided binary toxicity labels.
30
+ - **Metric:** ROC-AUC per task, with macro-average for reporting (mirrors leaderboard metric).
31
+ - **Data split:** For each task, drop rows with missing labels and perform K-fold CV (e.g., 5 folds) inside Optuna to make best use of labeled data.
32
+
33
+ ### 1.4 Optuna search space
34
+ Within `src/lightgbm_trainer.py`, expose an `objective(trial, task_name)` that:
35
+ 1. Samples:
36
+ - `learning_rate ∈ [1e-3, 0.2]` (log scale)
37
+ - `num_leaves ∈ [16, 256]`
38
+ - `max_depth ∈ [-1, 12]`
39
+ - `min_data_in_leaf ∈ [10, 200]`
40
+ - `feature_fraction ∈ [0.5, 1.0]`
41
+ - `bagging_fraction ∈ [0.5, 1.0]` with `bagging_freq ∈ [1, 10]`
42
+ - `lambda_l1`, `lambda_l2` (10^-8 to 10^1)
43
+ 2. Trains the LightGBM model on each CV split and averages ROC-AUC.
44
+ 3. Returns the negative mean ROC-AUC so Optuna can minimize the objective.
45
+
46
+ Persist the best hyperparameters per task into the config (or a JSON artifact) so `predict.py` can instantiate the booster with exact values. When data volume is small, Optuna’s `Study` can share the same random seed for reproducibility (`src/seed.py` can be reused).
47
+
48
+ ### 1.5 Deliverables for Phase 1
49
+ - Updated `train.py` calling into `src/lightgbm_trainer.train_single_task(task_name, features, labels, config)`.
50
+ - `checkpoints/stage1_{task}.txt` boosters (even though they are “stage 1”, they form the baseline deliverable).
51
+ - Validation report (per-task ROC-AUC) saved to `checkpoints/metrics_stage1.json`.
52
+ - `predict.py` loads each per-task LightGBM model, computes baseline descriptors on-the-fly, and returns predictions.
53
+
54
+ ---
55
+
56
+ ## 2. Phase 2 — Fingerprint-Based Representations
57
+
58
+ ### 2.1 Feature computation
59
+ Implement `src/features.py` with methods:
60
+ - `compute_ecfp(mol, radius=2, n_bits=1024)` using `GetMorganFingerprintAsBitVect`.
61
+ - `compute_map4(mol)` via MAP4 codebase (counts hashed patterns). Because MAP4 is computationally heavier, cache features to disk (e.g., `cache/fingerprints_{split}.npz`).
62
+ - `fingerprint_pipeline(smiles_list, fingerprint_type)` that accepts sanitized SMILES, constructs `Mol` objects, and returns a dense `np.ndarray`.
63
+
64
+ ### 2.2 Integration
65
+ - Update `train.py` to choose the fingerprint type from config (e.g., `config["features"]["type"] = "ecfp"`).
66
+ - Align `predict.py` to call the same fingerprint builder on incoming SMILES.
67
+ - Maintain metadata describing fingerprint dimensionality and type in a manifest (e.g., `checkpoints/features.json`) so inference knows how to parse the stored LightGBM feature order.
68
+
69
+ ### 2.3 Training flow
70
+ Apart from the enriched features, Phase 2 reuses the Phase 1 training loop. If resource constraints exist, we can:
71
+ - Run Optuna once on a representative task (e.g., NR-AhR) and reuse its best hyperparameters for all tasks; or
72
+ - Run Optuna briefly per task (e.g., 30 trials) and share results.
73
+
74
+ ### 2.4 Deliverables
75
+ - Fingerprint cache builders + unit tests (small set of SMILES).
76
+ - Configurable training/inference that toggles between baseline descriptors and fingerprint vectors.
77
+ - Updated metrics comparing descriptors vs. ECFP vs. MAP4.
78
+
79
+ ---
80
+
81
+ ## 3. Phase 3 — Cross-Task Label Augmentation
82
+
83
+ ### 3.1 Motivation
84
+ By incorporating predictions from other tasks, we expose LightGBM to shared toxicity patterns without building a fully joint model. This is especially valuable for underrepresented tasks where correlated labels provide additional signal.
85
+
86
+ ### 3.2 Feature construction
87
+ Given `T = 12` tasks and fingerprint dimension `D`, the augmented features for task `k` are:
88
+ ```
89
+ X_k = [fingerprint_vector (D dims), ŷ_1, …, ŷ_{k-1}, ŷ_{k+1}, …, ŷ_T]
90
+ ```
91
+ where `ŷ_t` are the stage-1 predictions for task `t` on the same molecule. Use floats instead of hard labels to preserve uncertainty.
92
+
93
+ ### 3.3 Implementation details
94
+ 1. **Collect stage-1 predictions.**
95
+ - After Phase 2 training, run inference with each stage-1 model on every molecule in train/val/test splits.
96
+ - Store the `N × T` prediction matrix in `checkpoints/stage1_predictions_{split}.npz`.
97
+ 2. **Align missing data.**
98
+ - If task `t` lacks a label for a molecule, mask it during stage-1 training but still compute predictions for other tasks so the feature matrix stays dense.
99
+ 3. **Data leakage prevention.**
100
+ - During training, use out-of-fold predictions (OOF) for the stage-1 features so models do not see their own ground-truth labels through the augmented vector.
101
+ - Implementation: For each fold, train stage-1 LightGBM on K-1 folds, predict on the held-out fold, and concatenate predictions.
102
+ 4. **Config surface.**
103
+ - `config["multitask"]["use_stage1_predictions"] = true/false`
104
+ - `config["multitask"]["prediction_source"] = "oof" | "full_train"` to switch between strict OOF features and simpler (but leakier) full-train predictions for debugging.
105
+
106
+ ### 3.4 Training
107
+ Once augmented features are ready, rerun the single-task LightGBM training per target (`stage2`). Hyperparameter search can be narrower because fingerprints already provide a strong baseline; focus on `num_leaves`, `feature_fraction`, and regularization strength.
108
+
109
+ ### 3.5 Deliverables
110
+ - Scripts that generate OOF prediction matrices.
111
+ - Updated `train.py` orchestration:
112
+ 1. Train Stage 1 models.
113
+ 2. Materialize cross-task prediction cache.
114
+ 3. Train Stage 2 models from augmented features.
115
+ - Metrics comparing Stage 1 vs. Stage 2 per task.
116
+
117
+ ---
118
+
119
+ ## 4. Phase 4 — Two-Stage Training & Inference
120
+
121
+ ### 4.1 Training orchestration
122
+ Pseudo-flow for `train.py`:
123
+
124
+ ```python
125
+ def train(config):
126
+ ds = load_dataset(...)
127
+ mols = preprocess.standardize(ds["train"]["smiles"])
128
+ fp_cache = features.fingerprint_pipeline(mols, config["features"])
129
+
130
+ stage1 = StageOneTrainer(config)
131
+ stage1.train_all_tasks(fp_cache, labels, splits)
132
+ stage1.save_models("checkpoints/stage1_*.txt")
133
+
134
+ pred_cache = stage1.generate_predictions(fp_cache, splits, use_oof=True)
135
+
136
+ stage2 = StageTwoTrainer(config)
137
+ stage2.train_all_tasks(fp_cache, pred_cache, labels)
138
+ stage2.save_models("checkpoints/stage2_*.txt")
139
+
140
+ dump_metrics(stage1.metrics, stage2.metrics)
141
+ ```
142
+
143
+ ### 4.2 Inference pipeline (`predict.py`)
144
+ 1. **Fingerprint computation:** identical to training (deterministic sanitization).
145
+ 2. **Stage-1 pass:** Load every `stage1_{task}.txt`, predict on the incoming SMILES batch, and collect predictions.
146
+ 3. **Stage-2 pass:** For each task `k`, build `[fingerprint, predicted_labels_except_k]` on-the-fly and evaluate the corresponding stage-2 booster.
147
+ 4. **Output:** Return the stage-2 predictions for leaderboard submission. Optionally include stage-1 scores in the response if needed for debugging (but the official output should stick to stage-2 values).
148
+
149
+ ### 4.3 Failure modes & mitigations
150
+ - **Unrecognized SMILES:** fall back to zeros or 0.5 predictions like the current baseline but log warnings so we can monitor failure rates.
151
+ - **Missing checkpoint:** raise an informative exception instructing users to rerun `train.py`.
152
+ - **Performance drift:** store SHA or timestamp metadata with checkpoints to trace which training configuration produced a given model.
153
+
154
+ ---
155
+
156
+ ## 5. Configuration & Experiment Tracking
157
+ Proposed structure for `config/config.json`:
158
+
159
+ ```json
160
+ {
161
+ "seed": 42,
162
+ "features": {
163
+ "type": "ecfp",
164
+ "radius": 2,
165
+ "n_bits": 1024,
166
+ "use_counts": false
167
+ },
168
+ "training": {
169
+ "n_folds": 5,
170
+ "n_optuna_trials": 50,
171
+ "lightgbm_params": {
172
+ "objective": "binary",
173
+ "metric": "auc",
174
+ "verbosity": -1
175
+ }
176
+ },
177
+ "multitask": {
178
+ "enabled": true,
179
+ "use_stage1_predictions": true,
180
+ "prediction_source": "oof"
181
+ }
182
+ }
183
+ ```
184
+
185
+ Track experiment results in `checkpoints/experiments.csv` with columns `[timestamp, fingerprint, stage, task, auc, params_hash]`.
186
+
187
+ ---
188
+
189
+ ## 6. Testing & Validation
190
+ - **Unit tests:** Ensure fingerprint builders reproduce known vectors (compare with RDKit reference) and that cross-task feature assembly drops the correct task column.
191
+ - **Integration tests:** Small toy dataset (3 tasks, <50 samples) to run the full Stage1→Stage2 pipeline quickly. Assert shapes of caches and that inference matches training predictions.
192
+ - **Performance tracking:** Plot per-task ROC-AUC improvements by phase to confirm each enhancement adds value.
193
+
194
+ ---
195
+
196
+ ## 7. Suggested Implementation Milestones
197
+ 1. **M1:** Skeleton LightGBM trainer + Optuna integration (Phase 1). ✓
198
+ 2. **M2:** Fingerprint computation module with caching + updated training/inference (Phase 2).
199
+ 3. **M3:** Stage-1 prediction cache + feature augmentation (Phase 3).
200
+ 4. **M4:** End-to-end Stage1→Stage2 orchestration, packaging of checkpoints, and inference updates (Phase 4).
201
+ 5. **M5:** Documentation + automated tests to guard against regressions.
202
+
203
+ This phased roadmap keeps the leaderboard interface intact while progressively increasing the modeling capacity from simple descriptors to multitask-enhanced fingerprints.
predict.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+ from typing import Dict, List
7
+
8
+ import joblib
9
+ import numpy as np
10
+
11
+ from src.constants import TARGET_NAMES
12
+ from src.features import FingerprintFeaturizer
13
+ from src.seed import set_seed
14
+
15
+ BASE_PREDICTION = 0.5
16
+
17
+
18
@lru_cache(maxsize=1)
def _load_manifest() -> Dict:
    """Read and memoize the training manifest written by train.py."""
    path = Path("./checkpoints/training_manifest.json")
    if not path.exists():
        raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
    return json.loads(path.read_text(encoding="utf-8"))
26
+
27
+
28
@lru_cache(maxsize=2)
def _load_stage_models(stage: str):
    """Load every available booster for `stage` ("stage1"/"stage2"), keyed by target.

    Returns an empty dict when the manifest carries no model_dir for the stage;
    targets without a checkpoint file on disk are simply omitted.
    """
    manifest = _load_manifest()
    model_dir = manifest.get(stage, {}).get("model_dir")
    if not model_dir:
        return {}
    base = Path(model_dir)
    loaded = {}
    for target in manifest.get("target_names", TARGET_NAMES):
        candidate = base / f"{target}.pkl"
        if candidate.exists():
            loaded[target] = joblib.load(candidate)
    return loaded
42
+
43
+
44
def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Run every stage-1 booster over the valid-molecule feature matrix.

    Columns whose model is missing keep the neutral BASE_PREDICTION score.
    Returns an (n_valid, n_targets) float32 matrix; empty input yields zeros
    with shape (0, n_targets).
    """
    models = _load_stage_models("stage1")
    n_rows, n_tasks = features.shape[0], len(target_names)
    if n_rows == 0:
        return np.zeros((0, n_tasks), dtype=np.float32)

    scores = np.full((n_rows, n_tasks), BASE_PREDICTION, dtype=np.float32)
    for col, name in enumerate(target_names):
        model = models.get(name)
        if model is None:
            continue
        # Honour early stopping when the fitted model recorded a best iteration.
        best = getattr(model, "best_iteration_", None)
        extra = {} if best is None else {"num_iteration": best}
        scores[:, col] = model.predict_proba(features, **extra)[:, 1]
    return scores
60
+
61
+
62
def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    """Refine stage-1 scores with stage-2 boosters trained on augmented features.

    For task k the stage-2 input is [fingerprints | stage-1 scores of every
    other task]. Tasks lacking a stage-2 model fall back to their stage-1
    column; when no stage-2 models exist at all, the stage-1 matrix is
    returned unchanged.
    """
    models = _load_stage_models("stage2")
    if not models:
        return stage1_preds

    n_rows = base_features.shape[0]
    scores = np.full((n_rows, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for col, name in enumerate(target_names):
        model = models.get(name)
        if model is None:
            scores[:, col] = stage1_preds[:, col]
            continue
        # Drop this task's own stage-1 column to avoid feeding it to itself.
        cross_task = np.delete(stage1_preds, col, axis=1)
        augmented = np.hstack([base_features, cross_task])
        best = getattr(model, "best_iteration_", None)
        extra = {} if best is None else {"num_iteration": best}
        scores[:, col] = model.predict_proba(augmented, **extra)[:, 1]
    return scores
90
+
91
+
92
def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    Invalid SMILES (those that fail standardization) receive the neutral
    BASE_PREDICTION for every target instead of being dropped, so the output
    always contains one entry per input string.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    # Fixed seed so inference is repeatable — presumably seeds numpy/random;
    # confirm against src/seed.py.
    set_seed(0)
    manifest = _load_manifest()
    target_names = manifest.get("target_names", TARGET_NAMES)
    feature_config = manifest.get("feature_config", {"type": "ecfp"})

    # Must use the same fingerprint configuration as training (from manifest).
    featurizer = FingerprintFeaturizer(feature_config)
    batch, features = featurizer.featurize_smiles(smiles_list)

    # Prediction matrices contain one row per *valid* molecule only.
    stage1_preds = _compute_stage1_predictions(features, target_names)
    stage2_preds = _compute_stage2_predictions(features, stage1_preds, target_names)

    predictions: Dict[str, Dict[str, float]] = {}
    # Walks the rows of the valid-only matrices in step with batch.mask.
    valid_idx = 0

    for original_smiles, is_valid in zip(smiles_list, batch.mask):
        if not is_valid:
            # Unparseable SMILES: fall back to the neutral score for all tasks.
            predictions[original_smiles] = {target: BASE_PREDICTION for target in target_names}
            continue

        row_preds = stage2_preds[valid_idx] if stage2_preds.size else np.full(len(target_names), BASE_PREDICTION)
        predictions[original_smiles] = {target: float(score) for target, score in zip(target_names, row_preds)}
        valid_idx += 1

    return predictions
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ numpy==1.26.2
4
+ python-dotenv
5
+ pandas==2.2.2
6
+ scikit-learn==1.7.1
7
+ pydantic
8
+ rdkit-pypi
9
+ datasets
10
+ lightgbm
11
+ optuna
12
+ joblib
13
+ map4
src/constants.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Canonical ordering of the 12 Tox21 assay targets. This order defines the
# column layout of every prediction matrix and the filenames of the per-task
# model checkpoints, so it must stay stable across training and inference.
# NOTE(review): the /metadata endpoint in app.py lists the same targets in a
# different order (NR-AR first); that listing appears informational only —
# confirm no caller relies on the two orders matching.
TARGET_NAMES = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Name of the DataFrame column that stores standardized (canonical) SMILES.
CANONICAL_SMILES_COLUMN = "canonical_smiles"
src/features.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Sequence
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from rdkit import DataStructs
9
+ from rdkit.Chem import AllChem
10
+
11
+ from .constants import CANONICAL_SMILES_COLUMN
12
+ from .preprocess import MoleculeBatch, filter_dataframe_by_mask, standardize_smiles
13
+
14
+ try:
15
+ from map4 import MAP4Calculator # type: ignore
16
+ except Exception: # pragma: no cover - optional dependency
17
+ MAP4Calculator = None
18
+
19
+
20
class FingerprintFeaturizer:
    """Compute molecular fingerprints with optional caching.

    Configured via the `features` section of config.json: fingerprint `type`
    ("ecfp" or "map4"), ECFP `radius`/`n_bits`/`use_counts`, MAP4 `map4_dim`,
    and an optional on-disk `cache_dir` for per-split feature caches.
    """

    def __init__(self, feature_config: Dict):
        self.config = feature_config
        self.fingerprint_type = feature_config.get("type", "ecfp").lower()
        self.radius = feature_config.get("radius", 2)
        self.n_bits = feature_config.get("n_bits", 1024)
        self.map4_dim = feature_config.get("map4_dim", 1024)
        self.use_counts = feature_config.get("use_counts", False)
        cache_dir = feature_config.get("cache_dir")
        self.cache_dir = Path(cache_dir) if cache_dir else None
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def featurize_dataframe(self, df: pd.DataFrame, split_name: str):
        """Featurize one split DataFrame (expects a "smiles" column).

        Returns (clean_df, features): rows whose SMILES failed standardization
        are dropped, and `features` is row-aligned with `clean_df`. Served
        from the per-split cache when present.
        NOTE(review): the cache key is only split name + fingerprint type, so
        stale caches must be deleted manually after changing radius/n_bits.
        """
        cache_payload = self._load_cache(split_name)
        if cache_payload is not None:
            mask = cache_payload["mask"]
            canonical_smiles = cache_payload["canonical_smiles"].tolist()
            features = cache_payload["features"]
            clean_df = filter_dataframe_by_mask(df, mask, canonical_smiles)
            return clean_df, features

        batch = standardize_smiles(df["smiles"].tolist())
        clean_df = filter_dataframe_by_mask(df, batch.mask, batch.canonical_smiles)
        features = self._compute_fingerprints(batch.mols)

        self._write_cache(split_name, batch.mask, batch.canonical_smiles, features)
        return clean_df, features

    def featurize_smiles(self, smiles: Sequence[str]) -> tuple[MoleculeBatch, np.ndarray]:
        """Featurize raw SMILES (inference path; never cached)."""
        batch = standardize_smiles(smiles)
        features = self._compute_fingerprints(batch.mols)
        return batch, features

    def _cache_path(self, split_name: str) -> Path | None:
        # None disables caching entirely (no cache_dir configured).
        if self.cache_dir is None:
            return None
        return self.cache_dir / f"{split_name}_{self.fingerprint_type}.npz"

    def _load_cache(self, split_name: str):
        """Return the NpzFile for this split, or None when absent/disabled."""
        cache_path = self._cache_path(split_name)
        if cache_path is None or not cache_path.exists():
            return None
        # allow_pickle is needed for the object-dtype canonical_smiles array.
        return np.load(cache_path, allow_pickle=True)

    def _write_cache(self, split_name: str, mask, canonical_smiles, features):
        """Persist mask/canonical SMILES/features for a split; no-op if disabled."""
        cache_path = self._cache_path(split_name)
        if cache_path is None:
            return
        np.savez(
            cache_path,
            mask=mask,
            canonical_smiles=np.array(canonical_smiles, dtype=object),
            features=features,
        )

    def _compute_fingerprints(self, mols):
        """Dispatch to the configured fingerprint implementation."""
        if not mols:
            # Keep the column count consistent even for an empty batch.
            dim = self._fingerprint_dimension()
            return np.zeros((0, dim), dtype=np.float32)

        if self.fingerprint_type == "ecfp":
            return self._compute_ecfp(mols)
        if self.fingerprint_type == "map4":
            return self._compute_map4(mols)
        raise ValueError(f"Unsupported fingerprint type: {self.fingerprint_type}")

    def _fingerprint_dimension(self) -> int:
        """Output width of the configured fingerprint."""
        if self.fingerprint_type == "map4":
            return self.map4_dim
        return self.n_bits

    def _compute_ecfp(self, mols):
        """Morgan/ECFP fingerprints as a dense (n_mols, n_bits) float32 array.

        NOTE(review): GetMorganFingerprint / GetMorganFingerprintAsBitVect are
        deprecated in recent RDKit in favour of rdFingerprintGenerator — works
        today but emits warnings; consider migrating.
        """
        fingerprints = np.zeros((len(mols), self.n_bits), dtype=np.float32)
        for idx, mol in enumerate(mols):
            if self.use_counts:
                # Unfolded count fingerprint, folded into n_bits by modulo so
                # counts from colliding hashes accumulate.
                fp = AllChem.GetMorganFingerprint(mol, self.radius)
                arr = np.zeros(self.n_bits, dtype=np.float32)
                for bit, value in fp.GetNonzeroElements().items():
                    arr[bit % self.n_bits] += value
            else:
                bitvect = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    self.radius,
                    nBits=self.n_bits,
                )
                arr = np.zeros(self.n_bits, dtype=np.float32)
                DataStructs.ConvertToNumpyArray(bitvect, arr)
            fingerprints[idx] = arr
        return fingerprints

    def _compute_map4(self, mols):
        """MAP4 fingerprints; requires the optional `map4` package."""
        if MAP4Calculator is None:
            raise ImportError(
                "MAP4 fingerprint requested but the `map4` package is not installed. "
                "Install it via `pip install map4` or switch features.type to 'ecfp'."
            )

        calc = MAP4Calculator(dimensions=self.map4_dim)
        fingerprints = np.zeros((len(mols), self.map4_dim), dtype=np.float32)
        for idx, mol in enumerate(mols):
            vec = np.array(calc.calculate(mol), dtype=np.float32)
            fingerprints[idx] = vec
        return fingerprints
src/lightgbm_trainer.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Dict, Optional, Sequence
7
+
8
+ import joblib
9
+ import lightgbm as lgb
10
+ import numpy as np
11
+ import optuna
12
+ import pandas as pd
13
+ from sklearn.metrics import roc_auc_score
14
+
15
+ from .constants import TARGET_NAMES
16
+
17
+
18
@dataclass
class TaskTrainingOutput:
    """Result bundle from training one per-task LightGBM classifier."""

    # Fitted classifier (refit with the best Optuna hyperparameters).
    model: lgb.LGBMClassifier
    # ROC-AUC of `model` on the validation split.
    val_auc: float
    # Boosting iteration selected by early stopping.
    best_iteration: int
    # Full hyperparameter dict used to fit `model`.
    best_params: Dict
25
+
26
def _sample_hyperparams(trial: optuna.Trial, base_params: Dict) -> Dict:
    """Merge Optuna-sampled LightGBM hyperparameters onto `base_params`.

    `base_params` is never mutated. Sampled values override caller-supplied
    keys of the same name; the fixed defaults at the bottom are applied only
    when the caller has not set them.
    """
    sampled = {
        "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 16, 256, log=True),
        "max_depth": trial.suggest_int("max_depth", -1, 12),
        "min_child_samples": trial.suggest_int("min_child_samples", 10, 200),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
        "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
        "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
    }
    merged = {**base_params, **sampled}
    for key, default in (
        ("objective", "binary"),
        ("metric", "auc"),
        ("verbosity", -1),
        ("boosting_type", "gbdt"),
        ("n_jobs", -1),
    ):
        merged.setdefault(key, default)
    return merged
47
+
48
+
49
def train_lightgbm_task(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    base_params: Dict,
    boosting_rounds: int,
    early_stopping_rounds: int,
    n_trials: int,
    seed: int,
) -> Optional[TaskTrainingOutput]:
    """Optuna-tune and fit a single-task LightGBM binary classifier.

    Runs `n_trials` Optuna trials maximizing validation ROC-AUC, then refits
    one model with the winning hyperparameters.

    Args:
        X_train / y_train: training features and binary labels.
        X_val / y_val: validation features and labels (used for early stopping
            and as the Optuna objective).
        base_params: fixed LightGBM params merged into every trial.
        boosting_rounds: maximum n_estimators per model.
        early_stopping_rounds: patience for the early-stopping callback.
        n_trials: number of Optuna trials.
        seed: random_state passed to every LGBMClassifier.

    Returns:
        TaskTrainingOutput, or None when either split is single-class (AUC is
        undefined in that case).
    """
    if len(np.unique(y_train)) < 2 or len(np.unique(y_val)) < 2:
        return None

    def _fit_model(params: Dict) -> lgb.LGBMClassifier:
        # Shared fit routine for trial models and the final refit.
        # BUGFIX: `verbose` was removed from the sklearn `fit()` API in
        # LightGBM >= 4.0, so passing `verbose=False` raises TypeError with an
        # unpinned lightgbm. Logging is already silenced via the params'
        # `verbosity: -1` and the early-stopping callback's `verbose=False`.
        model = lgb.LGBMClassifier(**params)
        model.fit(
            X_train,
            y_train,
            eval_set=[(X_val, y_val)],
            eval_metric="auc",
            callbacks=[
                lgb.early_stopping(
                    early_stopping_rounds,
                    first_metric_only=True,
                    verbose=False,
                )
            ],
        )
        return model

    def objective(trial: optuna.Trial) -> float:
        params = _sample_hyperparams(trial, base_params)
        params["n_estimators"] = boosting_rounds
        params["random_state"] = seed
        model = _fit_model(params)
        best_iter = getattr(model, "best_iteration_", boosting_rounds)
        preds = model.predict_proba(X_val, num_iteration=best_iter)[:, 1]
        return float(roc_auc_score(y_val, preds))

    # NOTE: no sampler seed is set here, so the search itself is not
    # reproducible across runs even though each fitted model is seeded.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)

    # FrozenTrial supports the suggest_* API and replays the stored values,
    # so this reconstructs the winning hyperparameter set exactly.
    best_params = _sample_hyperparams(study.best_trial, base_params)
    best_params["n_estimators"] = boosting_rounds
    best_params["random_state"] = seed

    final_model = _fit_model(best_params)

    best_iteration = getattr(final_model, "best_iteration_", boosting_rounds)
    val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
    val_auc = roc_auc_score(y_val, val_preds)

    return TaskTrainingOutput(
        model=final_model,
        val_auc=float(val_auc),
        best_iteration=int(best_iteration),
        best_params=best_params,
    )
119
+
120
+
121
def save_stage_metrics(metrics: Dict, path: Path):
    """Serialize a metrics mapping to pretty-printed JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
125
+
126
+
127
def train_stage_one_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one stage-1 LightGBM model per Tox21 target and persist artifacts.

    For each target: masks out rows with missing labels, tunes/fits a model
    via `train_lightgbm_task`, saves it to `checkpoint_dir/stage1/{task}.pkl`,
    and fills one column of the cross-task prediction matrices. Tasks are
    skipped (with a reason recorded in metrics) when the validation split is
    absent, labeled data is insufficient, or labels are single-class.

    Returns:
        dict with "train_full"/"val_full" (n_samples x n_tasks prediction
        matrices, 0.5 where a task was skipped; val_full is None without a
        validation split) and "metrics" (per-task status / val_auc).
    """
    stage_dir = checkpoint_dir / "stage1"
    stage_dir.mkdir(parents=True, exist_ok=True)

    # Training knobs, with the same defaults as config/config.json.
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    boosting_rounds = training_cfg.get("boosting_rounds", 1500)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    n_train = len(train_df)
    n_tasks = len(target_names)

    # Prediction matrices start at the neutral 0.5 so skipped tasks still
    # contribute a well-defined column to stage-2 features.
    train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
    val_preds = (
        np.full((len(val_df), n_tasks), 0.5, dtype=np.float32)
        if val_df is not None and val_features is not None
        else None
    )

    metrics: Dict[str, Dict] = {}
    params_dump: Dict[str, Dict] = {}

    for task_idx, task_name in enumerate(target_names):
        # Only rows with a label for this task participate in its training.
        train_mask = train_df[task_name].notna().values
        if val_df is None or val_features is None:
            metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
            continue

        val_mask = val_df[task_name].notna().values
        if train_mask.sum() < 2 or val_mask.sum() < 2:
            metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
            continue

        X_train_task = train_features[train_mask]
        y_train_task = train_df.loc[train_mask, task_name].astype(float).values
        X_val_task = val_features[val_mask]
        y_val_task = val_df.loc[val_mask, task_name].astype(float).values

        # AUC is undefined for single-class splits.
        if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
            metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
            continue

        task_result = train_lightgbm_task(
            X_train_task,
            y_train_task,
            X_val_task,
            y_val_task,
            base_params=base_params,
            boosting_rounds=boosting_rounds,
            early_stopping_rounds=early_stopping,
            n_trials=n_trials,
            seed=seed,
        )

        if task_result is None:
            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
            continue

        model = task_result.model
        best_iter = task_result.best_iteration

        model_path = stage_dir / f"{task_name}.pkl"
        joblib.dump(model, model_path)

        params_dump[task_name] = {
            **task_result.best_params,
            "best_iteration": best_iter,
            "val_auc": task_result.val_auc,
        }

        # NOTE(review): these are in-sample ("full_train") predictions over
        # ALL training rows, including the ones this model was fitted on. If
        # they feed stage-2 training, that is the leakier prediction_source
        # described in the design doc — OOF predictions would avoid leakage.
        full_train_preds = model.predict_proba(
            train_features,
            num_iteration=best_iter,
        )[:, 1]
        train_preds[:, task_idx] = full_train_preds.astype(np.float32)

        if val_preds is not None:
            full_val_preds = model.predict_proba(
                val_features,
                num_iteration=best_iter,
            )[:, 1]
            val_preds[:, task_idx] = full_val_preds.astype(np.float32)

        metrics[task_name] = {
            "val_auc": task_result.val_auc,
            "n_train_samples": int(train_mask.sum()),
            "n_val_samples": int(val_mask.sum()),
        }

    # Persist per-task metrics and the winning hyperparameters next to the models.
    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
    params_path = checkpoint_dir / "stage1_params.json"
    with params_path.open("w", encoding="utf-8") as f:
        json.dump(params_dump, f, indent=2)

    return {
        "train_full": train_preds,
        "val_full": val_preds,
        "metrics": metrics,
    }
src/model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+
8
class GIN(torch.nn.Module):
    """Graph Isomorphism Network for graph-level classification.

    Stacks `num_layers` GINConv blocks (each a 2-layer MLP with BatchNorm),
    pools node embeddings into a graph embedding by sum or mean, and applies
    a linear classifier. `forward` returns raw logits (no sigmoid/softmax),
    so callers are expected to pair it with a logits-based loss.
    """

    def __init__(self, num_features, num_classes, dropout, hidden_dim=128, num_layers=5, add_or_mean="add"):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        # Pooling mode: "add" (sum) or "mean".
        self.add_or_mean = add_or_mean
        # Dropout rate applied after every conv layer (a separate fixed 0.5
        # dropout is applied after pooling — see forward).
        self.dropout = dropout

        self.conv_layers = nn.ModuleList()

        # input features → hidden_dim
        mlp = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )
        self.conv_layers.append(GINConv(mlp, train_eps=True))

        # hidden GIN layers
        for _ in range(num_layers - 1):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim)
            )
            self.conv_layers.append(GINConv(mlp, train_eps=True))

        # Final classifier (after pooling)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute graph-level logits.

        Args:
            x: node feature matrix (presumably [num_nodes, num_features] —
                confirm against the dataset builder).
            edge_index: PyG edge index tensor.
            batch: PyG batch-assignment vector mapping nodes to graphs.
        """
        for conv in self.conv_layers:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Pool to get graph-level representation
        # NOTE(review): an unrecognized add_or_mean value silently skips
        # pooling and would return per-node outputs — consider validating it.
        if self.add_or_mean == "mean":
            x = global_mean_pool(x, batch)
        elif self.add_or_mean == "add":
            x = global_add_pool(x, batch)

        # Fixed 0.5 dropout on the pooled embedding, independent of self.dropout.
        x = F.dropout(x, p=0.5, training=self.training)
        return self.fc(x)
src/preprocess.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Sequence
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from datasets import load_dataset
9
+ from rdkit import Chem
10
+ from rdkit.Chem.MolStandardize import rdMolStandardize
11
+
12
+ from .constants import CANONICAL_SMILES_COLUMN
13
+
14
@dataclass
class MoleculeBatch:
    """Result of standardizing a batch of SMILES strings."""
    # RDKit molecules that survived standardization (aligned with canonical_smiles).
    mols: List[Chem.Mol]
    # Boolean array over the *input* sequence; True where standardization succeeded.
    mask: np.ndarray
    # Canonical SMILES, one per molecule in ``mols``.
    canonical_smiles: List[str]
19
+
20
+
21
def load_tox21_dataset(token: str | None, dataset_name: str) -> Dict[str, pd.DataFrame]:
    """Load every split of a Hugging Face dataset as a pandas DataFrame.

    Args:
        token: Hugging Face access token, or ``None`` for public datasets.
        dataset_name: Dataset identifier on the Hub (e.g. ``"ml-jku/tox21"``).

    Returns:
        Mapping of split name (``"train"``, ``"validation"``, ...) to DataFrame.
    """
    dataset = load_dataset(dataset_name, token=token)
    return {split_name: split.to_pandas() for split_name, split in dataset.items()}
28
+
29
+
30
def standardize_smiles(smiles: Sequence[str]) -> MoleculeBatch:
    """Standardize SMILES strings and return RDKit molecules with canonical SMILES.

    Each input is parsed, cleaned up, canonicalized to a preferred tautomer,
    converted to canonical SMILES, and re-parsed from that SMILES. Inputs that
    fail any step (or raise inside RDKit) are dropped; ``mask`` marks which
    input positions survived.
    """
    enumerator = rdMolStandardize.TautomerEnumerator()
    params = rdMolStandardize.CleanupParameters()

    kept_mols: List[Chem.Mol] = []
    kept_smiles: List[str] = []
    ok = np.zeros(len(smiles), dtype=bool)

    def _standardize_one(raw):
        """Return (mol, canonical_smiles), or None if the input is unusable."""
        mol = Chem.MolFromSmiles(raw)
        if mol is None:
            return None
        mol = rdMolStandardize.Cleanup(mol, params)
        mol = enumerator.Canonicalize(mol)
        canonical = Chem.MolToSmiles(mol)
        reparsed = Chem.MolFromSmiles(canonical)
        if reparsed is None:
            return None
        return reparsed, canonical

    for position, raw in enumerate(smiles):
        try:
            result = _standardize_one(raw)
        except Exception:
            # Best-effort cleanup: RDKit can raise on exotic inputs; skip them.
            continue
        if result is None:
            continue
        mol, canonical = result
        kept_mols.append(mol)
        kept_smiles.append(canonical)
        ok[position] = True

    return MoleculeBatch(mols=kept_mols, mask=ok, canonical_smiles=kept_smiles)
59
+
60
+
61
def filter_dataframe_by_mask(df: pd.DataFrame, mask: np.ndarray, canonical_smiles: Sequence[str]) -> pd.DataFrame:
    """Keep only the rows where ``mask`` is True and attach canonical SMILES.

    Returns a copy with a fresh 0..n-1 index and an extra
    ``CANONICAL_SMILES_COLUMN`` column aligned with the surviving rows.
    """
    kept = df.loc[mask].copy()
    kept = kept.reset_index(drop=True)
    kept[CANONICAL_SMILES_COLUMN] = list(canonical_smiles)
    return kept
src/seed.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+
6
+
7
def set_seed(seed: int = 42):
    """Make Python's and NumPy's global RNG state reproducible.

    Exports PYTHONHASHSEED for any child processes (NOTE(review): it cannot
    change the *running* interpreter's hash randomization, which is fixed at
    startup) and seeds the ``random`` and ``numpy.random`` generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
src/stage_two.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Optional, Sequence
5
+
6
+ import joblib
7
+ import numpy as np
8
+ import pandas as pd
9
+ from sklearn.metrics import roc_auc_score
10
+
11
+ from .constants import TARGET_NAMES
12
+ from .lightgbm_trainer import save_stage_metrics, train_lightgbm_task
13
+
14
+
15
+ def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
16
+ mask = np.ones(prediction_matrix.shape[1], dtype=bool)
17
+ mask[target_idx] = False
18
+ return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1)
19
+
20
+
21
def train_stage_two_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_preds: np.ndarray,
    stage1_val_preds: Optional[np.ndarray],
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one stage-2 LightGBM model per Tox21 task.

    Each task's feature matrix is the base fingerprint features augmented with
    the stage-1 predictions of all *other* tasks. Models are saved under
    ``checkpoint_dir / "stage2"``; metrics go to ``metrics_stage2.json``.

    Args:
        train_features / val_features: Base feature matrices per split.
        train_df / val_df: DataFrames holding the per-task label columns.
        config: Full config dict; reads ``training`` and ``seed`` keys.
        checkpoint_dir: Root directory for checkpoints and metrics.
        stage1_train_preds / stage1_val_preds: [n_samples, n_tasks] stage-1
            prediction matrices used as augmentation features.
        target_names: Task (label column) names to iterate over.

    Returns:
        ``{"metrics": {task_name: {...}}}`` with per-task AUC/best-iteration
        entries, or a skip status with a reason.
    """
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    boosting_rounds = training_cfg.get("boosting_rounds", 1500)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)

    metrics: Dict[str, Dict] = {}

    for task_idx, task_name in enumerate(target_names):
        # Only train on rows that actually carry a label for this task.
        mask = train_df[task_name].notna().values
        if mask.sum() == 0:
            metrics[task_name] = {"status": "skipped", "reason": "no labels"}
            continue

        augmented_train_matrix = _build_augmented_matrix(
            train_features[mask],
            stage1_train_preds[mask],
            task_idx,
        )
        y_train = train_df.loc[mask, task_name].astype(float).values

        # Validation data is required for early stopping / AUC; skip otherwise.
        if (
            val_features is None
            or val_df is None
            or stage1_val_preds is None
            or val_df[task_name].notna().sum() < 2
        ):
            metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
            continue

        val_mask = val_df[task_name].notna().values
        augmented_val_matrix = _build_augmented_matrix(
            val_features[val_mask],
            stage1_val_preds[val_mask],
            task_idx,
        )
        y_val = val_df.loc[val_mask, task_name].astype(float).values

        # AUC is undefined with a single class on either split.
        if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
            metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
            continue

        task_result = train_lightgbm_task(
            augmented_train_matrix,
            y_train,
            augmented_val_matrix,
            y_val,
            base_params=base_params,
            boosting_rounds=boosting_rounds,
            early_stopping_rounds=early_stopping,
            n_trials=n_trials,
            seed=seed,
        )

        if task_result is None:
            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
            continue

        model_path = stage_dir / f"{task_name}.pkl"
        joblib.dump(task_result.model, model_path)

        metrics[task_name] = {
            "val_auc": task_result.val_auc,
            "best_iteration": int(task_result.best_iteration),
        }

    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}
src/train_evaluate.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from sklearn.metrics import roc_auc_score
5
+
6
def masked_bce_loss(logits, labels, mask):
    """Binary cross-entropy averaged over only the labeled entries.

    Args:
        logits: [batch_size, num_classes] raw model outputs.
        labels: [batch_size, num_classes] 0/1 targets (filler values where unlabeled).
        mask:   [batch_size, num_classes] True where the label is valid.

    Returns:
        Mean BCE over valid entries. A batch with *no* valid labels returns
        0.0 instead of the original 0/0 NaN, which would have poisoned the
        backward pass and all subsequent parameter updates.
    """
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    per_entry = criterion(logits, labels)
    valid = mask.float()
    # Clamp the denominator so an all-masked batch yields 0 loss, not NaN.
    return (per_entry * valid).sum() / valid.sum().clamp_min(1.0)
16
+
17
def train_model(model, loader, optimizer, device):
    """Run one optimization epoch and return the dataset-averaged loss."""
    model.train()
    running = 0.0
    for graphs in loader:
        graphs = graphs.to(device)

        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.batch)  # [num_graphs, num_classes]
        batch_loss = masked_bce_loss(logits, graphs.y, graphs.mask)
        batch_loss.backward()
        optimizer.step()

        # Weight by the number of graphs so the epoch average is per-graph.
        running += batch_loss.item() * graphs.num_graphs
    return running / len(loader.dataset)
32
+
33
+
34
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the dataset-averaged masked BCE loss without gradient tracking."""
    model.eval()
    running = 0.0
    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)
        running += masked_bce_loss(logits, graphs.y, graphs.mask).item() * graphs.num_graphs
    return running / len(loader.dataset)
44
+
45
+
46
@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Mean ROC-AUC across labels, ignoring unlabeled entries.

    Labels with no valid entries, or with only one class present (where AUC
    is undefined), are left out of the average. Returns NaN if no label is
    scorable.
    """
    model.eval()
    trues, probs, masks = [], [], []

    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)

        # Sigmoid turns logits into per-label probabilities.
        probs.append(torch.sigmoid(logits).cpu())
        trues.append(graphs.y.cpu())
        masks.append(graphs.mask.cpu())

    y_true = torch.cat(trues, dim=0).numpy()
    y_prob = torch.cat(probs, dim=0).numpy()
    valid = torch.cat(masks, dim=0).numpy()

    scores = []
    for col in range(y_true.shape[1]):
        keep = valid[:, col].astype(bool)
        if keep.sum() == 0:
            continue
        try:
            scores.append(roc_auc_score(y_true[keep, col], y_prob[keep, col]))
        except ValueError:
            # Only one class present in this column: AUC undefined, skip it.
            continue

    return np.mean(scores) if len(scores) > 0 else float("nan")
77
+
78
@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Compute per-label ROC-AUC plus the mean over labels.

    Returns:
        (auc_array, mean_auc): ``auc_array`` is a float32 array with one
        entry per label — NaN where the label has no valid entries or only
        one class is present; ``mean_auc`` is ``np.nanmean`` over that array.
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []

    # Fix: the original wrapped this loop in a second `with torch.no_grad():`
    # even though the decorator already disables gradient tracking for the
    # whole function; the redundant context is removed.
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)

        # Store predictions (sigmoid → probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    # Compute AUC per class; NaN marks undefined columns.
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # only one class present
        else:
            auc = np.nan
        auc_list.append(auc)

    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs

    # Return both per-class and mean
    return auc_array, mean_auc
train.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Dict
7
+
8
+ import numpy as np
9
+ from dotenv import load_dotenv
10
+
11
+ from src.constants import TARGET_NAMES
12
+ from src.features import FingerprintFeaturizer
13
+ from src.lightgbm_trainer import train_stage_one_models
14
+ from src.preprocess import load_tox21_dataset
15
+ from src.seed import set_seed
16
+ from src.stage_two import train_stage_two_models
17
+
18
+
19
+ def _default_checkpoint_dir(config: Dict) -> Path:
20
+ checkpoint_cfg = config.get("output", {})
21
+ checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints")
22
+ path = Path(checkpoint_dir)
23
+ path.mkdir(parents=True, exist_ok=True)
24
+ return path
25
+
26
+
27
def train(config: Dict):
    """End-to-end training pipeline.

    Loads the Tox21 dataset, featurizes it, trains stage-1 LightGBM models,
    optionally trains stage-2 multitask-augmented models, and writes a
    manifest describing all produced artifacts.

    Args:
        config: Parsed JSON configuration; reads the optional keys
            ``dataset``, ``features``, ``multitask``, ``output`` and ``seed``.

    Raises:
        ValueError: If the dataset lacks 'train' or 'validation' splits.
    """
    load_dotenv()  # pull TOKEN (HF access token) from .env into the environment
    set_seed(config.get("seed", 42))
    token = os.getenv("TOKEN")

    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("name", "ml-jku/tox21")
    splits = load_tox21_dataset(token, dataset_name)

    if "train" not in splits or "validation" not in splits:
        raise ValueError("Dataset must provide 'train' and 'validation' splits.")

    # Featurize each split; returns the (possibly filtered) dataframe plus the
    # matching feature matrix.
    featurizer = FingerprintFeaturizer(config.get("features", {}))
    train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train")
    val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation")

    checkpoint_dir = _default_checkpoint_dir(config)
    cache_dir = checkpoint_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("==== Stage 1: Training baseline LightGBM models ====")
    stage1_artifacts = train_stage_one_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        target_names=TARGET_NAMES,
    )

    # Per-sample stage-1 prediction matrices for all tasks; stage 2 reuses
    # them as extra input features.
    stage1_train_full = stage1_artifacts["train_full"]
    stage1_val_full = stage1_artifacts["val_full"]

    # Cache stage-1 predictions so later stages/analysis can run without
    # retraining stage 1.
    np.savez(
        cache_dir / "stage1_train_predictions.npz",
        full=stage1_train_full,
        target_names=np.array(TARGET_NAMES, dtype=object),
    )
    if stage1_val_full is not None:
        np.savez(
            cache_dir / "stage1_validation_predictions.npz",
            full=stage1_val_full,
            target_names=np.array(TARGET_NAMES, dtype=object),
        )

    # Stage 2 is opt-in via config["multitask"]["enabled"].
    stage2_metrics = None
    multitask_cfg = config.get("multitask", {"enabled": False})
    if multitask_cfg.get("enabled", False):
        print("==== Stage 2: Training multitask-augmented LightGBM models ====")
        stage2_artifacts = train_stage_two_models(
            train_features,
            val_features,
            train_df,
            val_df,
            config,
            checkpoint_dir,
            stage1_train_full,
            stage1_val_full,
            target_names=TARGET_NAMES,
        )
        stage2_metrics = stage2_artifacts["metrics"]

    # Stage-2 artifact paths are recorded only when stage 2 actually ran.
    stage2_entry = {
        "enabled": bool(multitask_cfg.get("enabled", False)),
        "model_dir": str(checkpoint_dir / "stage2") if stage2_metrics is not None else None,
        "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_metrics is not None else None,
    }

    # The manifest ties configuration and artifact locations together so
    # downstream inference can reconstruct the exact training setup.
    manifest = {
        "feature_config": config.get("features", {}),
        "target_names": TARGET_NAMES,
        "dataset": dataset_cfg,
        "stage1": {
            "model_dir": str(checkpoint_dir / "stage1"),
            "metrics": str((checkpoint_dir / "metrics_stage1.json")),
        },
        "stage2": stage2_entry,
        "multitask": multitask_cfg,
        "seed": config.get("seed", 42),
    }

    manifest_path = checkpoint_dir / "training_manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    print("Training complete.")
114
+
115
+
116
if __name__ == "__main__":
    # Script entry point: load the JSON config and kick off training.
    config_path = Path("./config/config.json")
    with config_path.open("r", encoding="utf-8") as config_file:
        loaded_config = json.load(config_file)
    train(loaded_config)