Daniel Probst committed

Commit 444d15c · 1 Parent(s): 9678908

Initial commit
Dockerfile ADDED
@@ -0,0 +1,16 @@
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.11

RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,407 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
MODEL_CARD.md ADDED
@@ -0,0 +1,31 @@
# Model card - tox21_rf_classifier

### Model details
- Model name: Random Forest Tox21 Baseline
- Developer: JKU Linz
- Paper URL: https://link.springer.com/article/10.1023/A:1010933404324
- Model type / architecture:
  - Random Forest implemented using sklearn.ensemble.RandomForestClassifier.
  - Hyperparameters: [link to config](https://huggingface.co/spaces/ml-jku/tox21_rf_classifier/blob/main/config/config.json)
  - A separate single-task RF is trained for each Tox21 target.
  - Inference: Access via FastAPI. Upon a Tox21 prediction request, a target-specific RF
    model is called separately for each target; outputs are collected across all single-task
    models and returned.
- Model version: v0
- Model date: 14.10.2025
- Reproducibility: Code for full training is available and enables retraining from
  scratch.

### Intended use
This model serves as a baseline for evaluating and comparing toxicity prediction methods
across the 12 Tox21 pathway assays. It is not intended for clinical decision-making without
experimental validation.

### Metric
Each Tox21 task is evaluated using the area under the receiver operating characteristic curve
(AUC). Overall performance is reported as the mean AUC across all tasks.
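As a sketch, the mean-AUC metric can be computed like this (assuming `numpy` and `scikit-learn`; the per-task masking of missing labels reflects the sparse Tox21 label matrix, and the helper name is illustrative, not the repo's actual code):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def mean_auc(y_true, y_score):
    """Per-task ROC AUC (ignoring NaN labels), averaged over all tasks."""
    aucs = []
    for j in range(y_true.shape[1]):
        mask = ~np.isnan(y_true[:, j])  # Tox21 labels are sparse per task
        aucs.append(roc_auc_score(y_true[mask, j], y_score[mask, j]))
    return float(np.mean(aucs))
```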
### Training data
Tox21 training and validation sets.

### Evaluation data
Tox21 test set.
README.md CHANGED
@@ -1,11 +1,117 @@
1
  ---
2
- title: Tox21 Mhfp
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: red
6
  sdk: docker
7
  pinned: false
8
- license: mit
 
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Tox21 Random Forest Classifier
3
+ emoji: 🚀
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ license: cc-by-nc-4.0
9
+ short_description: Random Forest Baseline for Tox21
10
  ---
11
 
# Tox21 Random Forest Classifier

This repository hosts a Hugging Face Space that provides an API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/ml-jku/tox21_leaderboard).

Here, **Random Forest (RF)** models are trained on the Tox21 dataset, and the trained models are provided for inference. A separate RF model is trained for each of the twelve toxic effects. The input to the model is a **SMILES** string of a small molecule, and the output is a vector of 12 numeric values, one for each toxic effect of the Tox21 dataset.

**Important:** For leaderboard submission, your Space needs to include training code. The file `train.py` should train the model using the config specified inside the `config/` folder and save the final model parameters into a file inside the `checkpoints/` folder. The model should be trained on the [Tox21 dataset](https://huggingface.co/datasets/ml-jku/tox21) provided on Hugging Face. The datasets can be loaded like this:
```python
from datasets import load_dataset

ds = load_dataset("ml-jku/tox21", token=token)  # token: your Hugging Face access token
train_df = ds["train"].to_pandas()
val_df = ds["validation"].to_pandas()
```

Additionally, the Space needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary, with SMILES as keys and dictionaries of target-name/prediction pairs as values. Consequently, any preprocessing of SMILES strings must be executed on the fly during inference.
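The required contract can be sketched as follows (a minimal placeholder: the constant 0.0 predictions stand in for real checkpoint-backed inference, and the `TARGETS` list mirrors the assay names listed in `app.py`):

```python
from typing import Dict, List

TARGETS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]

def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """Skeleton matching the leaderboard contract: list of SMILES in,
    nested {smiles: {target: prediction}} dictionary out."""
    results: Dict[str, Dict[str, float]] = {}
    for smi in smiles_list:
        # A real implementation featurizes `smi` on the fly and queries the
        # per-target RF models loaded from `checkpoints/`.
        results[smi] = {target: 0.0 for target in TARGETS}
    return results
```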
# Repository Structure
- `predict.py` - defines the `predict()` function required by the leaderboard (entry point for inference).
- `app.py` - FastAPI application wrapper (can be used as-is).
- `preprocess.py` - preprocesses SMILES strings to generate feature descriptors and saves the results as NPZ files in `data/`.
- `train.py` - trains and saves a model using the config in the `config/` folder.
- `config/` - the config file used by `train.py`.
- `logs/` - all logs of `train.py`, the saved model, and predictions on the validation set.
- `data/` - the RF operates on numerical data; during preprocessing in `preprocess.py`, two NPZ files containing molecule features are created and saved here.
- `checkpoints/` - the saved model used by `predict.py`.
- `src/` - core model & preprocessing logic:
  - `preprocess.py` - SMILES preprocessing logic
  - `model.py` - RF model class with processing, saving, and loading logic
  - `utils.py` - utility functions

# Quickstart with Spaces

You can easily adapt this project for your own Hugging Face account:

- Open this Space on Hugging Face.

- Click "Duplicate this Space" (top-right corner).

- Modify `src/` for your preprocessing pipeline and model class.

- Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.

- Modify `train.py` and/or `preprocess.py` according to your model and preprocessing pipeline.

- Modify the file inside `config/` to contain all hyperparameters that are set in `train.py`.

That's it: your model will be available as an API endpoint for the Tox21 Leaderboard.
# Installation
To run (and train) the random forest, clone the repository and install the dependencies:

```bash
git clone https://huggingface.co/spaces/ml-jku/tox21_rf_classifier
cd tox21_rf_classifier

conda create -n tox21_rf_cls python=3.11
conda activate tox21_rf_cls
pip install -r requirements.txt
```
# Training

To train the Random Forest model from scratch, run:

```bash
python preprocess.py
python train.py
```

These commands will:
1. Load and preprocess the Tox21 training dataset
2. Train a Random Forest classifier
3. Store the resulting model in the `checkpoints/` directory
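Internally, training amounts to fitting one single-task classifier per assay. A rough sketch of that loop (a hypothetical helper with simplified hyperparameters; the real `train.py` reads its settings from `config/config.json` and uses the repo's own preprocessing):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def train_per_task(X, y, target_names, seed=0):
    """Fit one RF per Tox21 task, skipping molecules without a label for that task."""
    models = {}
    for j, name in enumerate(target_names):
        labeled = ~np.isnan(y[:, j])  # Tox21 labels are sparse per task
        clf = RandomForestClassifier(n_estimators=100, random_state=seed)
        clf.fit(X[labeled], y[labeled, j].astype(int))
        models[name] = clf
    return models
```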
# Inference

For inference, you only need `predict.py`.

Example usage inside Python:

```python
from predict import predict

smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles_list)

print(results)
```

The output is a nested dictionary in the format:

```python
{
    "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
    "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
    "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
}
```

# Notes

- Adapting `predict.py`, `train.py`, `config/`, and `checkpoints/` is required for leaderboard submission.

- Preprocessing must be done inside `predict.py`, not just `train.py`.
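To make that last note concrete, on-the-fly preprocessing means `predict()` must turn raw SMILES strings into feature vectors itself at request time. A toy illustration (a hashed character-n-gram fingerprint, a deliberately simple stand-in for the actual descriptor pipeline in `src/preprocess.py`):

```python
import hashlib
import numpy as np

def smiles_fingerprint(smiles: str, n_bits: int = 64, n: int = 3) -> np.ndarray:
    """Toy binary fingerprint: hash each character n-gram of the SMILES
    string into a fixed-size bit vector."""
    fp = np.zeros(n_bits, dtype=np.uint8)
    for i in range(max(len(smiles) - n + 1, 1)):
        gram = smiles[i : i + n]
        idx = int(hashlib.md5(gram.encode()).hexdigest(), 16) % n_bits
        fp[idx] = 1
    return fp
```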
app.py ADDED
@@ -0,0 +1,78 @@
"""
This is the main entry point for the FastAPI application.
The app handles the request to predict toxicity for a list of SMILES strings.
"""

# ---------------------------------------------------------------------------------------
# Dependencies and global variable definition
import os
from typing import List, Dict, Optional
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel, Field

from predict import predict as predict_func

API_KEY = os.getenv("API_KEY")  # set via Space Secrets


# ---------------------------------------------------------------------------------------
class Request(BaseModel):
    smiles: List[str] = Field(min_items=1, max_items=1000)


class Response(BaseModel):
    predictions: dict
    model_info: Dict[str, str] = {}


app = FastAPI(title="toxicity-api")


@app.get("/")
def root():
    return {
        "message": "Toxicity Prediction API",
        "endpoints": {
            "/metadata": "GET - API metadata and capabilities",
            "/healthz": "GET - Health check",
            "/predict": "POST - Predict toxicity for SMILES",
        },
        "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
    }


@app.get("/metadata")
def metadata():
    return {
        "name": "Tox21 Random Forest Classifier",
        "version": "0.1.0",
        "max_batch_size": 256,
        "tox_endpoints": [
            "NR-AR",
            "NR-AR-LBD",
            "NR-AhR",
            "NR-Aromatase",
            "NR-ER",
            "NR-ER-LBD",
            "NR-PPAR-gamma",
            "SR-ARE",
            "SR-ATAD5",
            "SR-HSE",
            "SR-MMP",
            "SR-p53",
        ],
    }


@app.get("/healthz")
def healthz():
    return {"ok": True}


@app.post("/predict", response_model=Response)
def predict(request: Request):
    predictions = predict_func(request.smiles)
    return {
        "predictions": predictions,
        "model_info": {"name": "Tox21 Random Forest Classifier", "version": "0.1.0"},
    }
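Note that `app.py` reads `API_KEY` from the environment, but the code shown never checks it, even though the root endpoint's usage string mentions an Authorization header. A minimal bearer-token check could look like this (a hypothetical `authorize` helper in framework-agnostic plain Python; inside a FastAPI handler one would raise `HTTPException(status_code=401)` on failure):

```python
def authorize(authorization_header, api_key):
    """Return True iff the Authorization header carries the expected bearer token."""
    if api_key is None:
        # No key configured -> open access
        return True
    if not authorization_header or not authorization_header.startswith("Bearer "):
        return False
    return authorization_header[len("Bearer "):] == api_key
```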
checkpoints/.gitkeep ADDED
File without changes
config/config.json ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+     "seed": 0,
+     "debug": "false",
+     "device": "cpu",
+
+     "log_folder": "logs/",
+
+     "data_folder": "data/",
+     "cvfold": 4,
+     "ecfp": {
+         "radius": 3,
+         "fpsize": 8192
+     },
+     "merge_train_val": "true",
+     "descriptors": ["mhfps", "tox", "maccs", "rdkit_descrs"],
+     "feature_selection": {
+         "use": "true",
+         "min_var": 0.01,
+         "max_corr": 0.95,
+         "max_features": -1,
+         "min_var__feature_keys": ["mhfp", "tox", "maccs", "rdkit_descrs"],
+         "max_corr__feature_keys": ["mhfp", "tox", "maccs", "rdkit_descrs"],
+         "min_var__independent_keys": "false",
+         "max_corr__independent_keys": "false"
+     },
+     "feature_quantilization": {
+         "use": "true",
+         "feature_keys": ["rdkit_descrs"]
+     },
+     "max_samples": -1,
+     "scaler": "standard",
+     "preprocessor_path": "checkpoints/preprocessor.joblib",
+
+     "ckpt_path": "checkpoints/rf_alltasks.joblib",
+     "model_config": {
+         "NR-AR": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 5,
+             "n_estimators": 1000
+         },
+         "NR-AR-LBD": {
+             "max_depth": 12,
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 5,
+             "n_estimators": 1000
+         },
+         "NR-AhR": {
+             "max_depth": "none",
+             "max_features": "log2",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "NR-Aromatase": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 4,
+             "min_samples_split": 12,
+             "n_estimators": 1000
+         },
+         "NR-ER": {
+             "max_depth": 10,
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "NR-ER-LBD": {
+             "max_depth": 8,
+             "max_features": "sqrt",
+             "min_samples_leaf": 2,
+             "min_samples_split": 5,
+             "n_estimators": 1000
+         },
+         "NR-PPAR-gamma": {
+             "max_depth": "none",
+             "max_features": "log2",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "SR-ARE": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 5,
+             "n_estimators": 1000
+         },
+         "SR-ATAD5": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "SR-HSE": {
+             "max_depth": 16,
+             "max_features": "log2",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "SR-MMP": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 2,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         },
+         "SR-p53": {
+             "max_depth": "none",
+             "max_features": "sqrt",
+             "min_samples_leaf": 1,
+             "min_samples_split": 2,
+             "n_estimators": 1000
+         }
+     }
+ }
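The config above encodes booleans and nulls as the string literals `"true"`, `"false"`, and `"none"`, while the training logs show the loaded config with Python `True`/`False`/`None`. A minimal sketch of what `normalize_config` (from `src.utils`) might do, assuming it simply maps those literals recursively, is:

```python
def normalize_config(cfg):
    # Recursively map the string literals "true"/"false"/"none" used in
    # config.json to Python True/False/None; leave all other values as-is.
    literals = {"true": True, "false": False, "none": None}
    if isinstance(cfg, dict):
        return {k: normalize_config(v) for k, v in cfg.items()}
    if isinstance(cfg, list):
        return [normalize_config(v) for v in cfg]
    if isinstance(cfg, str) and cfg.lower() in literals:
        return literals[cfg.lower()]
    return cfg


cfg = normalize_config({"debug": "false", "NR-AR": {"max_depth": "none"}})
print(cfg)  # {'debug': False, 'NR-AR': {'max_depth': None}}
```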
logs/train_2025-11-21_17-59-45.log ADDED
@@ -0,0 +1,15 @@
+ 2025-11-21 17:59:45,120 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['mhfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 17:59:45,120 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
logs/train_2025-11-21_21-13-33.log ADDED
@@ -0,0 +1,15 @@
+ 2025-11-21 21:13:33,779 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['mhfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 21:13:33,779 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
logs/train_2025-11-21_21-29-10.log ADDED
@@ -0,0 +1,18 @@
+ 2025-11-21 21:29:10,146 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['ecfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['ecfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['ecfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 21:29:10,146 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ 2025-11-21 21:29:10,953 [INFO] Fitted RandomForestClassifier will be saved as: checkpoints/rf_alltasks.joblib
+ 2025-11-21 21:29:12,299 [INFO] Start training.
+ 2025-11-21 21:29:12,299 [INFO] Fit task NR-AR using 9645 samples
logs/train_2025-11-21_21-38-38.log ADDED
@@ -0,0 +1,15 @@
+ 2025-11-21 21:38:38,350 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['mhfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 21:38:38,350 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
logs/train_2025-11-21_21-42-58.log ADDED
@@ -0,0 +1,30 @@
+ 2025-11-21 21:42:58,767 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['mhfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 21:42:58,767 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ 2025-11-21 21:42:59,576 [INFO] Fitted RandomForestClassifier will be saved as: checkpoints/rf_alltasks.joblib
+ 2025-11-21 21:43:00,826 [INFO] Start training.
+ 2025-11-21 21:43:00,826 [INFO] Fit task NR-AR using 9645 samples
+ 2025-11-21 21:43:07,192 [INFO] Fit task NR-AR-LBD using 8844 samples
+ 2025-11-21 21:43:12,861 [INFO] Fit task NR-AhR using 8432 samples
+ 2025-11-21 21:43:15,695 [INFO] Fit task NR-Aromatase using 7431 samples
+ 2025-11-21 21:43:20,528 [INFO] Fit task NR-ER using 7953 samples
+ 2025-11-21 21:43:26,434 [INFO] Fit task NR-ER-LBD using 9031 samples
+ 2025-11-21 21:43:31,881 [INFO] Fit task NR-PPAR-gamma using 8442 samples
+ 2025-11-21 21:43:33,798 [INFO] Fit task SR-ARE using 7395 samples
+ 2025-11-21 21:43:40,050 [INFO] Fit task SR-ATAD5 using 9354 samples
+ 2025-11-21 21:43:46,811 [INFO] Fit task SR-HSE using 8409 samples
+ 2025-11-21 21:43:49,279 [INFO] Fit task SR-MMP using 7551 samples
+ 2025-11-21 21:43:55,132 [INFO] Fit task SR-p53 using 8894 samples
+ 2025-11-21 21:44:02,332 [INFO] Finished training.
logs/train_2025-11-21_21-45-08.log ADDED
@@ -0,0 +1,32 @@
+ 2025-11-21 21:45:08,091 [INFO] Config: {'seed': 0, 'debug': False, 'device': 'cpu', 'log_folder': 'logs/', 'data_folder': 'data/', 'cvfold': 4, 'ecfp': {'radius': 3, 'fpsize': 8192}, 'merge_train_val': True, 'descriptors': ['mhfps', 'tox', 'maccs', 'rdkit_descrs'], 'feature_selection': {'use': True, 'min_var': 0.01, 'max_corr': 0.95, 'max_features': -1, 'min_var__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'max_corr__feature_keys': ['mhfp', 'tox', 'maccs', 'rdkit_descrs'], 'min_var__independent_keys': False, 'max_corr__independent_keys': False}, 'feature_quantilization': {'use': True, 'feature_keys': ['rdkit_descrs']}, 'max_samples': -1, 'scaler': 'standard', 'preprocessor_path': 'checkpoints/preprocessor.joblib', 'ckpt_path': 'checkpoints/rf_alltasks.joblib', 'model_config': {'NR-AR': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AR-LBD': {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-AhR': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-Aromatase': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}, 'NR-ER': {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'NR-ER-LBD': {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}, 'NR-PPAR-gamma': {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-ARE': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}, 'SR-ATAD5': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-HSE': {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-MMP': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}, 'SR-p53': {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}}}
+ 2025-11-21 21:45:08,092 [INFO] Model config:
+ Model config:
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': 12, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 12, 'n_estimators': 1000}
+ {'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 8, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': 16, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 1000}
+ {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 1000}
+ 2025-11-21 21:45:08,925 [INFO] Fitted RandomForestClassifier will be saved as: checkpoints/rf_alltasks.joblib
+ 2025-11-21 21:45:10,035 [INFO] Start training.
+ 2025-11-21 21:45:10,036 [INFO] Fit task NR-AR using 9645 samples
+ 2025-11-21 21:45:16,392 [INFO] Fit task NR-AR-LBD using 8844 samples
+ 2025-11-21 21:45:22,091 [INFO] Fit task NR-AhR using 8432 samples
+ 2025-11-21 21:45:24,931 [INFO] Fit task NR-Aromatase using 7431 samples
+ 2025-11-21 21:45:29,749 [INFO] Fit task NR-ER using 7953 samples
+ 2025-11-21 21:45:35,716 [INFO] Fit task NR-ER-LBD using 9031 samples
+ 2025-11-21 21:45:41,136 [INFO] Fit task NR-PPAR-gamma using 8442 samples
+ 2025-11-21 21:45:43,052 [INFO] Fit task SR-ARE using 7395 samples
+ 2025-11-21 21:45:49,365 [INFO] Fit task SR-ATAD5 using 9354 samples
+ 2025-11-21 21:45:56,269 [INFO] Fit task SR-HSE using 8409 samples
+ 2025-11-21 21:45:58,780 [INFO] Fit task SR-MMP using 7551 samples
+ 2025-11-21 21:46:04,725 [INFO] Fit task SR-p53 using 8894 samples
+ 2025-11-21 21:46:12,055 [INFO] Finished training.
+ 2025-11-21 21:46:13,441 [INFO] Save model as: checkpoints/rf_alltasks.joblib
+ 2025-11-21 21:46:13,459 [INFO] Save preprocessor as: checkpoints/preprocessor.joblib
predict.py ADDED
@@ -0,0 +1,86 @@
+ """
+ This file includes a predict function for the Tox21 dataset.
+ As input it takes a list of SMILES strings and it outputs a nested dictionary
+ with SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ import json
+ import copy
+ from collections import defaultdict
+
+ import joblib
+ import numpy as np
+ from tqdm import tqdm
+
+ from src.model import Tox21RFClassifier
+ from src.preprocess import create_descriptors, FeaturePreprocessor
+ from src.utils import TASKS, normalize_config
+
+
+ # ---------------------------------------------------------------------------------------
+ CONFIG_FILE = "./config/config.json"
+
+
+ def predict(
+     smiles_list: list[str], default_prediction: float = 0.5
+ ) -> dict[str, dict[str, float]]:
+     """Applies the classifier to a list of SMILES strings. Returns
+     ``default_prediction`` for any molecule that could not be cleaned.
+
+     Args:
+         smiles_list (list[str]): list of SMILES strings
+         default_prediction (float): prediction assigned to molecules that failed cleaning
+
+     Returns:
+         dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
+     """
+     print(f"Received {len(smiles_list)} SMILES strings")
+
+     with open(CONFIG_FILE, "r") as f:
+         config = json.load(f)
+     config = normalize_config(config)
+
+     features, is_clean = create_descriptors(
+         smiles_list, config["descriptors"], **config["ecfp"]
+     )
+     print(f"Created descriptors for {sum(is_clean)} molecules.")
+     print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
+
+     # set up model and preprocessor
+     model = Tox21RFClassifier()
+     preprocessor = FeaturePreprocessor(
+         feature_selection_config=config["feature_selection"],
+         feature_quantilization_config=config["feature_quantilization"],
+         descriptors=config["descriptors"],
+         max_samples=config["max_samples"],
+         scaler=config["scaler"],
+     )
+
+     model.load(config["ckpt_path"])
+     print(f"Loaded model from {config['ckpt_path']}")
+
+     state = joblib.load(config["preprocessor_path"])
+     preprocessor.set_state(state)
+     print(f"Loaded preprocessor from {config['preprocessor_path']}")
+
+     # make predictions
+     predictions = defaultdict(dict)
+
+     print("Create predictions:")
+     for target in tqdm(TASKS):
+         X = copy.deepcopy(features)
+         X = {descr: array[is_clean] for descr, array in X.items()}
+         X = preprocessor.transform(X)
+
+         preds = np.empty_like(is_clean, dtype=np.float64)
+         preds[~is_clean] = default_prediction
+         preds[is_clean] = model.predict(target, X)
+
+         for smiles, pred in zip(smiles_list, preds):
+             predictions[smiles][target] = float(pred)
+         if config["debug"]:
+             break
+
+     return predictions
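`predict()` merges per-target model outputs back into the original input order: positions where cleaning failed receive `default_prediction`, and positions where it succeeded consume the model's predictions in order. The same masking pattern in a stdlib-only sketch (with dummy prediction values standing in for `model.predict`):

```python
def fill_predictions(is_clean, clean_preds, default_prediction=0.5):
    # Merge predictions back into the original molecule order:
    # positions where is_clean is False get default_prediction,
    # positions where it is True consume clean_preds in order.
    it = iter(clean_preds)
    return [next(it) if ok else default_prediction for ok in is_clean]


# Two of three molecules survived cleaning; the middle one did not.
preds = fill_predictions([True, False, True], [0.9, 0.1])
print(preds)  # [0.9, 0.5, 0.1]
```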
preprocess.py ADDED
@@ -0,0 +1,70 @@
1
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
2
+
3
+ """
4
+ This files includes a the data processing for Tox21.
5
+ As an input it takes a list of SMILES and it outputs a nested dictionary with
6
+ SMILES and target names as keys.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import argparse
12
+
13
+ import numpy as np
14
+
15
+ from src.preprocess import create_descriptors, get_tox21_split
16
+ from src.utils import TASKS, HF_TOKEN, create_dir, normalize_config
17
+
18
+ parser = argparse.ArgumentParser(
19
+ description="Data preprocessing script for the Tox21 dataset"
20
+ )
21
+
22
+ parser.add_argument(
23
+ "--config",
24
+ type=str,
25
+ default="config/config.json",
26
+ )
27
+
28
+
29
+ def main(config):
30
+ """Create molecule descriptors for HF Tox21 dataset"""
31
+ ds = get_tox21_split(HF_TOKEN, cvfold=config["cvfold"])
32
+
33
+ splits = ["train", "validation"]
34
+ for split in splits:
35
+
36
+ print(f"Preprocess {split} molecules")
37
+
38
+ ds_split = ds[split]
39
+ smiles = list(ds_split["smiles"])
40
+
41
+ features, clean_mol_mask = create_descriptors(
42
+ smiles, config["descriptors"], **config["ecfp"]
43
+ )
44
+
45
+ labels = []
46
+ for task in TASKS:
47
+ labels.append(ds_split[task].to_numpy())
48
+ labels = np.stack(labels, axis=1)
49
+
50
+ save_path = os.path.join(config["data_folder"], f"tox21_{split}_cv{config['cvfold']}.npz")
51
+ with open(save_path, "wb") as f:
52
+ np.savez(
53
+ f,
54
+ clean_mol_mask=clean_mol_mask,
55
+ labels=labels,
56
+ **features,
57
+ )
58
+ print(f"Saved preprocessed {split} split under {save_path}")
59
+ print("Preprocessing finished successfully")
60
+
61
+
62
+ if __name__ == "__main__":
63
+ args = parser.parse_args()
64
+
65
+ with open(args.config, "r") as f:
66
+ config = json.load(f)
67
+ config = normalize_config(config)
68
+
69
+ create_dir(config["data_folder"])
70
+ main(config)
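The script stores each split with `np.savez`, packing the boolean mask, the label matrix, and one array per descriptor into a single archive. A small round-trip sketch of that format (the arrays and the file name here are hypothetical stand-ins for the real preprocessed data):

```python
import numpy as np
import os
import tempfile

# hypothetical stand-ins for the arrays the script saves
features = {"ecfps": np.random.rand(5, 8), "maccs": np.random.rand(5, 167)}
clean_mol_mask = np.array([True, True, False, True, True])
labels = np.random.randint(0, 2, size=(5, 12)).astype(float)

path = os.path.join(tempfile.mkdtemp(), "tox21_train_cv4.npz")
with open(path, "wb") as f:
    np.savez(f, clean_mol_mask=clean_mol_mask, labels=labels, **features)

# downstream training code can recover each descriptor array by key
data = np.load(path)
print(sorted(data.files))
```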
pyproject.toml ADDED
@@ -0,0 +1,19 @@
1
+ [project]
2
+ name = "tox-21-mhfp-classifier"
3
+ version = "0.1.0"
4
+ description = "Tox21 toxicity prediction with MHFP fingerprints and per-task random forest classifiers"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "datasets>=4.4.1",
9
+ "fastapi>=0.121.3",
10
+ "joblib==1.5.2",
11
+ "mhfp==1.9.6",
12
+ "numpy==2.3.3",
13
+ "rdkit==2025.9.1",
14
+ "scikit-learn==1.7.1",
15
+ "statsmodels==0.14.5",
16
+ "tabulate>=0.9.0",
17
+ "torch==2.8.0",
18
+ "uvicorn[standard]>=0.38.0",
19
+ ]
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ fastapi
2
+ uvicorn[standard]
3
+ statsmodels==0.14.5
4
+ rdkit==2025.9.1
5
+ numpy==2.3.3
6
+ scikit-learn==1.7.1
7
+ joblib==1.5.2
8
+ tabulate
9
+ datasets
10
+ torch==2.8.0
11
+ mhfp==1.9.6
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (152 Bytes). View file
 
src/__pycache__/model.cpython-313.pyc ADDED
Binary file (3.8 kB). View file
 
src/__pycache__/preprocess.cpython-313.pyc ADDED
Binary file (33 kB). View file
 
src/__pycache__/utils.cpython-313.pyc ADDED
Binary file (15.8 kB). View file
 
src/model.py ADDED
@@ -0,0 +1,80 @@
1
+ """
2
+ This files includes a RF model for Tox21.
3
+ As an input it takes a list of SMILES and it outputs a nested dictionary with
4
+ SMILES and target names as keys.
5
+ """
6
+
7
+ # ---------------------------------------------------------------------------------------
8
+ # Dependencies
9
+ import joblib
10
+ import numpy as np
11
+ from sklearn.ensemble import RandomForestClassifier
12
+
13
+ from .utils import TASKS
14
+
15
+
16
+ # ---------------------------------------------------------------------------------------
17
+ class Tox21RFClassifier:
18
+ """A random forest classifier that assigns a toxicity score to a given SMILES string."""
19
+
20
+ def __init__(self, seed: int = 42, config: dict = None):
21
+ """Initialize a random forest classifier for each of the 12 Tox21 tasks.
22
+
23
+ Args:
24
+ seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
25
+ """
26
+ self.tasks = TASKS
27
+
28
+ self.models = {
29
+ task: RandomForestClassifier(
30
+ random_state=seed,
31
+ n_jobs=8,
32
+ **({"n_estimators": 1000} if config is None else config[task]),
33
+ )
34
+ for task in self.tasks
35
+ }
36
+
37
+ def load(self, path: str) -> None:
38
+ """Load model from filepath
39
+
40
+ Args:
41
+ path (str): filepath to model checkpoint
42
+ """
43
+ self.models = joblib.load(path)
44
+
45
+ def save(self, path: str) -> None:
46
+ """Save model to filepath
47
+
48
+ Args:
49
+ path (str): filepath to model checkpoint
50
+ """
51
+ joblib.dump(self.models, path)
52
+
53
+ def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
54
+ """Train the random forest for a given task
55
+
56
+ Args:
57
+ task (str): task to train
58
+ X (np.ndarray): training features
59
+ y (np.ndarray): training labels
60
+ """
61
+ assert task in self.tasks, f"Unknown task: {task}"
62
+ _X, _y = X.copy(), y.copy()
63
+ self.models[task].fit(_X, _y)
64
+
65
+ def predict(self, task: str, X: np.ndarray) -> np.ndarray:
66
+ """Predicts labels for a given Tox21 target using molecule features
67
+
68
+ Args:
69
+ task (str): the Tox21 target to predict for
70
+ X (np.ndarray): molecule features used for prediction
71
+
72
+ Returns:
73
+ np.ndarray: predicted probability for positive class
74
+ """
75
+ assert task in self.tasks, f"Unknown task: {task}"
76
+ assert (
77
+ len(X.shape) == 2
78
+ ), f"Function expects 2D np.array. Current shape: {X.shape}"
79
+ _X = X.copy()
80
+ return self.models[task].predict_proba(_X)[:, 1]
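The class above keeps one independent `RandomForestClassifier` per Tox21 target and reads the positive-class probability from column 1 of `predict_proba`. A self-contained sketch of that pattern on synthetic data (the task subset, feature matrix, and labels are made up for illustration):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
tasks = ["NR-AR", "SR-p53"]  # subset of the 12 Tox21 targets

# one independent forest per task, as in Tox21RFClassifier
models = {t: RandomForestClassifier(n_estimators=25, random_state=42) for t in tasks}

X = rng.normal(size=(60, 16))
y = {t: (X[:, i] > 0).astype(int) for i, t in enumerate(tasks)}

for t in tasks:
    models[t].fit(X, y[t])

# predict_proba(...)[:, 1] gives the positive-class (toxic) probability
probs = models["NR-AR"].predict_proba(X[:5])[:, 1]
print(probs.shape)
```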
src/preprocess.py ADDED
@@ -0,0 +1,704 @@
1
+ import copy
2
+ import json
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from datasets import load_dataset
9
+ from sklearn.base import BaseEstimator, TransformerMixin
10
+ from sklearn.feature_selection import VarianceThreshold
11
+ from sklearn.preprocessing import StandardScaler, FunctionTransformer
12
+ from statsmodels.distributions.empirical_distribution import ECDF
13
+
14
+ from rdkit import Chem, DataStructs
15
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
16
+ from rdkit.Chem.rdchem import Mol
17
+ from rdkit.Chem.rdMHFPFingerprint import MHFPEncoder
18
+
19
+ from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin
20
+
21
+
22
+ class SquashScaler(TransformerMixin, BaseEstimator):
23
+ """
24
+ Scaler that performs sequential standardization, nonlinearity (tanh), and
25
+ re-standardization. Inspired by DeepTox (Mayr et al., 2016)
26
+ """
27
+
28
+ def __init__(self):
29
+ self.scaler1 = StandardScaler()
30
+ self.scaler2 = StandardScaler()
31
+
32
+ def fit(self, X):
33
+ _X = X.copy()
34
+ _X = self.scaler1.fit_transform(_X)
35
+ _X = np.tanh(_X)
36
+ self.scaler2.fit(_X)
37
+ self.is_fitted_ = True
38
+ return self
39
+
40
+ def transform(self, X):
41
+ _X = X.copy()
42
+ _X = self.scaler1.transform(_X)
43
+ _X = np.tanh(_X)
44
+ return self.scaler2.transform(_X)
45
+
46
+
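The `SquashScaler` pipeline is standardize, `tanh` (which squashes outliers into (-1, 1)), then standardize again. A minimal numpy/sklearn sketch of the same three steps on made-up skewed data:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.lognormal(size=(200, 3))  # skewed, outlier-heavy features

# standardize -> tanh squashes extreme values -> re-standardize
s1, s2 = StandardScaler(), StandardScaler()
Z = s2.fit_transform(np.tanh(s1.fit_transform(X)))

print(np.abs(Z.mean(axis=0)).max())  # ~0 after re-standardization
```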
47
+ SCALER_REGISTRY = {
48
+ None: FunctionTransformer,
49
+ "standard": StandardScaler,
50
+ "squash": SquashScaler,
51
+ }
52
+
53
+
54
+ class SubSampler(TransformerMixin, BaseEstimator):
55
+ """
56
+ Preprocessor that randomly samples `max_samples` from data.
57
+
58
+ Args:
59
+ max_samples (int): Maximum allowed samples. If -1, all samples are retained.
60
+
61
+ Input:
62
+ np.ndarray: A 2D NumPy array of shape (n_samples, n_features).
63
+
64
+ Output:
65
+ np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
66
+ """
67
+
68
+ def __init__(self, *, max_samples=-1):
69
+ self.max_samples = max_samples
70
+ self.is_fitted_ = True
71
+
72
+ def fit(self, X: np.ndarray, y: np.ndarray | None = None):
73
+ return self
74
+
75
+ def transform(
76
+ self, X: np.ndarray, y: np.ndarray | None = None
77
+ ) -> np.ndarray | tuple[np.ndarray]:
78
+
79
+ _X = X.copy()
80
+ _y = y.copy() if y is not None else None
81
+
82
+ if self.max_samples > 0 and _X.shape[0] > self.max_samples:
83
+ resample_idxs = np.random.choice(
84
+ np.arange(_X.shape[0]), size=(self.max_samples,), replace=False
85
+ )
86
+ _X = _X[resample_idxs]
87
+ _y = _y[resample_idxs] if _y is not None else None
88
+
89
+ if _y is None:
90
+ return _X
91
+ return _X, _y
92
+
93
+
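The core of `SubSampler` is a single index draw that is applied to features and labels together, so the two stay aligned. A minimal sketch with synthetic arrays:

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 4))
y = rng.integers(0, 2, size=100)

max_samples = 10
if X.shape[0] > max_samples:
    # one shared index draw keeps features and labels aligned
    idx = rng.choice(X.shape[0], size=max_samples, replace=False)
    X_sub, y_sub = X[idx], y[idx]

print(X_sub.shape, y_sub.shape)
```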
94
+ class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
95
+ """
96
+ Preprocessor that performs feature selection based on variance and correlation.
97
+
98
+ This transformer selects features that:
99
+ 1. Have variance above a specified threshold.
100
+ 2. Are below a given pairwise correlation threshold.
101
+ 3. Among the remaining features, keeps only the top `max_features` with the highest variance.
102
+
103
+ The input and output are both dictionaries mapping feature types to their corresponding
104
+ feature matrices.
105
+
106
+ Args:
107
+ min_var (float): Minimum variance required for a feature to be retained.
108
+ max_corr (float): Maximum allowed correlation between features.
109
+ Features exceeding this threshold with others are removed.
110
+ max_features (int): Maximum number of features to keep after filtering.
111
+ If -1, all remaining features are retained.
112
+ feature_keys (list[str]): Features to apply feature selection to.
113
+ independent_keys (bool): Apply filtering only within features types.
114
+
115
+ Input:
116
+ dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
117
+ and each value is a 2D NumPy array of shape (n_samples, n_features).
118
+
119
+ Output:
120
+ dict[str, np.ndarray]: A dictionary with the same keys as the input,
121
+ containing only the selected features for each feature type.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ *,
127
+ min_var=0.0,
128
+ max_corr=1.0,
129
+ max_features=-1,
130
+ feature_keys=None,
131
+ min_var__feature_keys=None,
132
+ max_corr__feature_keys=None,
133
+ max_features__feature_keys=None,
134
+ min_var__independent_keys=False,
135
+ max_corr__independent_keys=False,
136
+ max_features__independent_keys=False,
137
+ ):
138
+ self.min_var = min_var
139
+ self.max_corr = max_corr
140
+ self.max_features = max_features
141
+
142
+ self.min_var__feature_keys = min_var__feature_keys
143
+ self.max_corr__feature_keys = max_corr__feature_keys
144
+ self.max_features__feature_keys = max_features__feature_keys
145
+
146
+ self.min_var__independent_keys = min_var__independent_keys
147
+ self.max_corr__independent_keys = max_corr__independent_keys
148
+ self.max_features__independent_keys = max_features__independent_keys
149
+
150
+ super().__init__(feature_keys=feature_keys)
151
+
152
+ def _get_min_var_mask(self, X: np.ndarray, *args) -> np.ndarray:
153
+ var_thresh = VarianceThreshold(threshold=self.min_var)
154
+ return var_thresh.fit(X).get_support() # mask
155
+
156
+ def _get_max_corr_mask(
157
+ self, X: np.ndarray, prev_feature_mask: np.ndarray
158
+ ) -> np.ndarray:
159
+ _prev_feature_mask = prev_feature_mask.copy()
160
+ corr_matrix = np.corrcoef(X[:, _prev_feature_mask], rowvar=False)
161
+ upper_tri = np.triu(corr_matrix, k=1)
162
+ to_keep = np.ones((sum(_prev_feature_mask),), dtype=bool)
163
+ for i in range(upper_tri.shape[0]):
164
+ for j in range(upper_tri.shape[1]):
165
+ if upper_tri[i, j] > self.max_corr:
166
+ to_keep[j] = False
167
+
168
+ _prev_feature_mask[_prev_feature_mask] = to_keep
169
+ return _prev_feature_mask
170
+
171
+ def _get_max_features_mask(
172
+ self, X: np.ndarray, prev_feature_mask: np.ndarray
173
+ ) -> np.ndarray:
174
+ _prev_feature_mask = prev_feature_mask.copy()
175
+ # keep the max_features features with the highest variance
176
+ feature_vars = np.nanvar(X[:, _prev_feature_mask], axis=0)
177
+ order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]
178
+ keep_feat_idx = np.arange(len(_prev_feature_mask))[_prev_feature_mask][order]
179
+ _prev_feature_mask = np.isin(
180
+ np.arange(len(_prev_feature_mask)), keep_feat_idx, assume_unique=True
181
+ )
182
+ return _prev_feature_mask
183
+
184
+ def apply_filter(self, filter, X, prev_feature_mask):
185
+ mask = prev_feature_mask.copy()
186
+ func = self.__getattribute__(f"_get_{filter}_mask")
187
+ feature_keys = self.__getattribute__(f"{filter}__feature_keys")
188
+
189
+ if self.__getattribute__(f"{filter}__independent_keys"):
190
+ for key in feature_keys:
191
+ key_mask = self._curr_keys == key
192
+ mask[key_mask] = func(X[:, key_mask], mask[key_mask])
193
+
194
+ else:
195
+ feature_key_mask = np.isin(self._curr_keys, feature_keys)
196
+ mask[feature_key_mask] = func(
197
+ X[:, feature_key_mask], mask[feature_key_mask]
198
+ )
199
+ return mask
200
+
201
+ def fit(self, X: dict[str, np.ndarray]):
202
+ _X = self.dict_to_array(X)
203
+ feature_mask = np.ones((_X.shape[1]), dtype=bool)
204
+
205
+ # select features with at least min_var variation
206
+ if self.min_var > 0.0:
207
+ if self.min_var__independent_keys:
208
+ for key in self.min_var__feature_keys:
209
+ key_mask = self._curr_keys == key
210
+ feature_mask[key_mask] = self._get_min_var_mask(_X[:, key_mask])
211
+
212
+ else:
213
+ feature_key_mask = np.isin(self._curr_keys, self.min_var__feature_keys)
214
+ feature_mask[feature_key_mask] = self._get_min_var_mask(
215
+ _X[:, feature_key_mask]
216
+ )
217
+
218
+ # drop features whose pairwise correlation exceeds max_corr
219
+ if self.max_corr < 1.0:
220
+ if self.max_corr__independent_keys:
221
+ for key in self.max_corr__feature_keys:
222
+ key_mask = self._curr_keys == key
223
+ subset = _X[:, key_mask]
224
+ feature_mask[key_mask] = self._get_max_corr_mask(
225
+ subset, feature_mask[key_mask]
226
+ )
227
+ else:
228
+ feature_key_mask = np.isin(self._curr_keys, self.max_corr__feature_keys)
229
+ feature_mask[feature_key_mask] = self._get_max_corr_mask(
230
+ _X[:, feature_key_mask], feature_mask[feature_key_mask]
231
+ )
232
+
233
+ if self.max_features == 0:
234
+ raise ValueError(
235
+ f"max_features (={self.max_features}) must be -1 or larger 0."
236
+ )
237
+ elif self.max_features > 0:
238
+ if self.max_features__independent_keys:
239
+ for key in self.max_features__feature_keys:
240
+ key_mask = self._curr_keys == key
241
+ feature_mask[key_mask] = self._get_max_features_mask(
242
+ _X[:, key_mask], feature_mask[key_mask]
243
+ )
244
+ else:
245
+ feature_key_mask = np.isin(
246
+ self._curr_keys, self.max_features__feature_keys
247
+ )
248
+ feature_mask[feature_key_mask] = self._get_max_features_mask(
249
+ _X[:, feature_key_mask], feature_mask[feature_key_mask]
250
+ )
251
+
252
+ self._feature_mask = feature_mask
253
+ self.is_fitted_ = True
254
+ return self
255
+
256
+ def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
257
+ _X = self.dict_to_array(X)
258
+ _X = _X[:, self._feature_mask]
259
+ self._curr_keys = self._curr_keys[self._feature_mask]
260
+ return self.array_to_dict(_X)
261
+
262
+
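The two central filters of `FeatureSelector` are a variance threshold followed by a correlation filter over the surviving columns. A compact sketch of both steps with `VarianceThreshold` and `np.corrcoef` on synthetic data (the 0.95 cutoff is an illustrative choice, not the repo's config value):

```python
import numpy as np
from sklearn.feature_selection import VarianceThreshold

rng = np.random.default_rng(0)
base = rng.normal(size=(100, 1))
# col 1 is a scaled copy of col 0; last col is constant
X = np.hstack([base, base * 1.001, rng.normal(size=(100, 2)), np.zeros((100, 1))])

# step 1: drop (near-)constant features
var_mask = VarianceThreshold(threshold=0.0).fit(X).get_support()

# step 2: among survivors, drop one member of each highly correlated pair
Xv = X[:, var_mask]
corr = np.abs(np.triu(np.corrcoef(Xv, rowvar=False), k=1))
keep = np.ones(Xv.shape[1], dtype=bool)
for i in range(corr.shape[0]):
    for j in range(corr.shape[1]):
        if corr[i, j] > 0.95:
            keep[j] = False

print(var_mask.sum(), keep.sum())
```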
263
+ class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
264
+ """
265
+ Preprocessor that transforms features into empirical quantiles using ECDFs.
266
+
267
+ This transformer applies an Empirical Cumulative Distribution Function (ECDF)
268
+ to each feature and replaces feature values with their corresponding quantile
269
+ ranks. The transformation is applied independently to each feature type.
270
+
271
+ Both input and output are dictionaries mapping feature types to their
272
+ corresponding feature matrices.
273
+
274
+ Args:
275
+ feature_keys (list[str]): Features to apply quantile creation to.
276
+
277
+ Input:
278
+ dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
279
+ and each value is a 2D NumPy array of shape (n_samples, n_features).
280
+
281
+ Output:
282
+ dict[str, np.ndarray]: A dictionary with the same keys as the input,
283
+ where each feature value is replaced by its corresponding ECDF quantile rank.
284
+ """
285
+
286
+ def __init__(self, *, feature_keys=None):
287
+ self._ecdfs = None
288
+ super().__init__(feature_keys=feature_keys)
289
+
290
+ def fit(self, X: dict[str, np.ndarray]):
291
+ _X = self.dict_to_array(X)
292
+ ecdfs = []
293
+ for column in range(_X.shape[1]):
294
+ raw_values = _X[:, column].reshape(-1)
295
+ ecdfs.append(ECDF(raw_values))
296
+ self._ecdfs = ecdfs
297
+ self.is_fitted_ = True
298
+ return self
299
+
300
+ def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
301
+ _X = self.dict_to_array(X)
302
+
303
+ quantiles = np.zeros_like(_X)
304
+ for column in range(_X.shape[1]):
305
+ raw_values = _X[:, column].reshape(-1)
306
+ ecdf = self._ecdfs[column]
307
+ q = ecdf(raw_values)
308
+ quantiles[:, column] = q
309
+
310
+ return self.array_to_dict(quantiles)
311
+
312
+
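`QuantileCreator` fits one ECDF per column on training data and maps new values to their empirical quantile rank. The same transform can be sketched without statsmodels via `np.searchsorted` (this numpy version matches statsmodels' `ECDF`, which returns the fraction of training values <= x):

```python
import numpy as np

def ecdf_transform(train_col, values):
    # fraction of training values <= each query value
    sorted_train = np.sort(train_col)
    return np.searchsorted(sorted_train, values, side="right") / len(sorted_train)

train = np.array([1.0, 2.0, 3.0, 4.0])
q = ecdf_transform(train, np.array([0.5, 2.0, 10.0]))
print(q)  # [0.  0.5 1. ]
```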
313
+ class FeaturePreprocessor(TransformerMixin, BaseEstimator):
314
+ """This class implements the feature preprocessing from a dictionary of molecule features."""
315
+
316
+ def __init__(
317
+ self,
318
+ feature_selection_config: dict[str, Any],
319
+ feature_quantilization_config: dict[str, Any],
320
+ descriptors: list[str],
321
+ max_samples: int = -1,
322
+ scaler: str = "standard",
323
+ ):
324
+ self.descriptors = descriptors
325
+
326
+ self.feature_quantilization_config = copy.deepcopy(
327
+ feature_quantilization_config
328
+ )
329
+ self.use_feat_quant = self.feature_quantilization_config.pop("use")
330
+ self.quantile_creator = QuantileCreator(**self.feature_quantilization_config)
331
+
332
+ self.feature_selection_config = copy.deepcopy(feature_selection_config)
333
+ self.use_feat_selec = self.feature_selection_config.pop("use")
334
+ self.feature_selection_config["feature_keys"] = descriptors
335
+ self.feature_selector = FeatureSelector(**self.feature_selection_config)
336
+
337
+ self.max_samples = max_samples
338
+ self.sub_sampler = SubSampler(max_samples=max_samples)
339
+
340
+ self.scaler = SCALER_REGISTRY[scaler]()
341
+
342
+ def __getstate__(self):
343
+ state = super().__getstate__()
344
+ state["quantile_creator"] = self.quantile_creator.__getstate__()
345
+ state["feature_selector"] = self.feature_selector.__getstate__()
346
+ state["sub_sampler"] = self.sub_sampler.__getstate__()
347
+ state["scaler"] = self.scaler.__getstate__()
348
+ return state
349
+
350
+ def __setstate__(self, state):
351
+ _state = copy.deepcopy(state)
352
+ self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
353
+ self.feature_selector.__setstate__(_state.pop("feature_selector"))
354
+ self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
355
+ self.scaler.__setstate__(_state.pop("scaler"))
356
+ super().__setstate__(_state)
357
+
358
+ def get_state(self):
359
+ return self.__getstate__()
360
+
361
+ def set_state(self, state):
362
+ return self.__setstate__(state)
363
+
364
+ def fit(self, X: dict[str, np.ndarray]):
365
+ """Fit the processor transformers"""
366
+ _X = copy.deepcopy(X)
367
+
368
+ if self.use_feat_quant:
369
+ _X = self.quantile_creator.fit_transform(_X)
370
+
371
+ if self.use_feat_selec:
372
+ _X = self.feature_selector.fit_transform(_X)
373
+
374
+ _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
375
+ self.scaler.fit(_X)
376
+ return self
377
+
378
+ def transform(
379
+ self, X: np.ndarray, y: np.ndarray | None = None
380
+ ) -> np.ndarray | tuple[np.ndarray]:
381
+
382
+ _X = X.copy()
383
+ _y = y.copy() if y is not None else None
384
+
385
+ if self.use_feat_quant:
386
+ _X = self.quantile_creator.transform(_X)
387
+ if self.use_feat_selec:
388
+ _X = self.feature_selector.transform(_X)
389
+ _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
390
+ _X = self.scaler.transform(_X)
391
+
392
+ if _y is None:
393
+ _X = self.sub_sampler.transform(_X)
394
+ return _X
395
+
396
+ _X, _y = self.sub_sampler.transform(_X, _y)
397
+ return _X, _y
398
+
399
+
400
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
401
+ """This function creates cleaned RDKit mol objects from a list of SMILES.
402
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
403
+ Modification by Antonia Ebner:
404
+ - skip uncleanable molecules
405
+ - return clean molecule mask
406
+
407
+ Args:
408
+ smiles (list[str]): list of SMILES
409
+
410
+ Returns:
411
+ list[Mol]: list of cleaned molecules
412
+ np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
413
+ index `i` could not be cleaned and was removed.
414
+ """
415
+ sm = Standardizer(canon_taut=True)
416
+
417
+ clean_mol_mask = list()
418
+ mols = list()
419
+ for i, smile in enumerate(smiles):
420
+ mol = Chem.MolFromSmiles(smile)
421
+ standardized_mol, _ = sm.standardize_mol(mol)
422
+ is_cleaned = standardized_mol is not None
423
+ clean_mol_mask.append(is_cleaned)
424
+ if not is_cleaned:
425
+ continue
426
+ can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
427
+ mols.append(can_mol)
428
+
429
+ return mols, np.array(clean_mol_mask)
430
+
431
+
432
+ def create_ecfp_fps(mols: list[Mol], radius=3, fpsize=2048, **kwargs) -> np.ndarray:
433
+ """This function creates ECFP fingerprints for a list of molecules.
434
+ Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
435
+
436
+ Args:
437
+ mols (list[Mol]): list of molecules
438
+
439
+ Returns:
440
+ np.ndarray: ECFP fingerprints of molecules
441
+ """
442
+ ecfps = list()
443
+
444
+ gen = rdFingerprintGenerator.GetMorganGenerator(
445
+ countSimulation=True, fpSize=fpsize, radius=radius
446
+ )
447
+ for mol in mols:
448
+ fp_sparse_vec = gen.GetCountFingerprint(mol)
449
+
450
+ fp = np.zeros((0,), np.int8)
451
+ DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
452
+
453
+ ecfps.append(fp)
454
+
455
+ return np.array(ecfps)
456
+
457
+
458
+ def create_mhfp_fps(
459
+ mols: list[Mol], radius=3, fpsize=2048, seed=42, **kwargs
460
+ ) -> np.ndarray:
461
+ """This function creates MHFP fingerprints for a list of molecules.
462
+ Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
463
+
464
+ Args:
465
+ mols (list[Mol]): list of molecules
466
+
467
+ Returns:
468
+ np.ndarray: MHFP fingerprints of molecules
469
+ """
470
+ mhfps = list()
471
+
472
+ enc = MHFPEncoder(fpsize, seed)
473
+
474
+ for mol in mols:
475
+ hash_values = np.array(enc.EncodeMol(mol, radius=radius))
476
+ folded = np.zeros(fpsize, dtype=np.uint8)
477
+
478
+ if len(hash_values) > 0:
479
+ folded[hash_values % fpsize] = 1
480
+
481
+ mhfps.append(folded)
482
+
483
+ return np.array(mhfps)
484
+
485
+
486
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
487
+ """This function creates MACCS keys for a list of molecules.
488
+
489
+ Args:
490
+ mols (list[Mol]): list of molecules
491
+
492
+ Returns:
493
+ np.ndarray: MACCS keys of molecules
494
+ """
495
+ maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
496
+ return np.array(maccs)
497
+
498
+
499
+ def get_tox_patterns(filepath: str):
500
+ """This retrieves the tox SMARTS patterns defined in filepath.
501
+ Args:
502
+ filepath (str): path to a JSON file containing tox SMARTS patterns
503
+ """
504
+ # load patterns
505
+ with open(filepath) as f:
506
+ smarts_list = [s[1] for s in json.load(f)]
507
+
508
+ # Code does not work for this case
509
+ assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
510
+
511
+ # Chem.MolFromSmarts takes a long time so it pays off to parse all the smarts first
512
+ # and then use them for all molecules. This gives a huge speedup over existing code.
513
+ # a list of patterns, whether to negate the match result and how to join them to obtain one boolean value
514
+ all_patterns = []
515
+ for smarts in smarts_list:
516
+ patterns = [] # list of smarts-patterns
517
+ # value for each of the patterns above. Negates the values of the above later.
518
+ negations = []
519
+
520
+ if " AND " in smarts:
521
+ smarts = smarts.split(" AND ")
522
+ merge_any = False # If an ' AND ' is found all 'subsmarts' have to match
523
+ else:
524
+ # If there is an ' OR ' present it's enough is any of the 'subsmarts' match.
525
+ # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
526
+ smarts = smarts.split(" OR ")
527
+ merge_any = True
528
+
529
+ # for all subsmarts check if they are preceded by 'NOT '
530
+ for s in smarts:
531
+ neg = s.startswith("NOT ")
532
+ if neg:
533
+ s = s[4:]
534
+ patterns.append(Chem.MolFromSmarts(s))
535
+ negations.append(neg)
536
+
537
+ all_patterns.append((patterns, negations, merge_any))
538
+ return all_patterns
539
+
540
+
541
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
542
+ """Matches the tox patterns against a molecule. Returns a boolean array"""
543
+ tox_data = []
544
+ for mol in mols:
545
+ mol_features = []
546
+ for patts, negations, merge_any in patterns:
547
+ matches = [mol.HasSubstructMatch(p) for p in patts]
548
+ matches = [m != n for m, n in zip(matches, negations)]
549
+ if merge_any:
550
+ pres = any(matches)
551
+ else:
552
+ pres = all(matches)
553
+ mol_features.append(pres)
554
+
555
+ tox_data.append(np.array(mol_features))
556
+
557
+ return np.array(tox_data)
558
+
559
+
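The pattern expressions in `get_tox_patterns` are flat `AND`/`OR` combinations with optional `NOT ` prefixes (the code asserts the two connectives never co-occur in one expression). The string-parsing half of that logic, isolated without RDKit (the example expression is hypothetical):

```python
def parse_pattern_expr(smarts: str):
    # split on AND or OR; the source asserts they never co-occur
    if " AND " in smarts:
        parts, merge_any = smarts.split(" AND "), False  # all subpatterns must match
    else:
        parts, merge_any = smarts.split(" OR "), True  # any subpattern may match
    patterns, negations = [], []
    for s in parts:
        neg = s.startswith("NOT ")
        patterns.append(s[4:] if neg else s)
        negations.append(neg)
    return patterns, negations, merge_any

pats, negs, merge_any = parse_pattern_expr("NOT [OH] AND c1ccccc1")
print(pats, negs, merge_any)
```

In `create_tox_features`, each parsed pattern is matched with `HasSubstructMatch`, XOR-ed with its negation flag, and the results are merged with `any`/`all` depending on `merge_any`.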
560
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
561
+ """This function creates RDKit descriptors for a list of molecules.
562
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
563
+
564
+ Args:
565
+ mols (list[Mol]): list of molecules
566
+
567
+ Returns:
568
+ np.ndarray: RDKit descriptors of molecules
569
+ """
570
+ rdkit_descriptors = list()
571
+
572
+ for mol in mols:
573
+ descrs = []
574
+ for _, descr_calc_fn in Descriptors._descList:
575
+ descrs.append(descr_calc_fn(mol))
576
+
577
+ descrs = np.array(descrs)
578
+ descrs = descrs[USED_200_DESCR]
579
+ rdkit_descriptors.append(descrs)
580
+
581
+ return np.array(rdkit_descriptors)
582
+
583
+
584
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
585
+ """Create quantile values for given features using the columns
586
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
587
+
588
+ Args:
589
+ raw_features (np.ndarray): values to put into quantiles
590
+ ecdfs (list): ECDFs to use
591
+
592
+ Returns:
593
+ np.ndarray: computed quantiles
594
+ """
595
+ quantiles = np.zeros_like(raw_features)
596
+
597
+ for column in range(raw_features.shape[1]):
598
+ raw_values = raw_features[:, column].reshape(-1)
599
+ ecdf = ecdfs[column]
600
+ q = ecdf(raw_values)
601
+ quantiles[:, column] = q
602
+
603
+ return quantiles
604
+
605
+
606
+ def fill(features, mask, value=np.nan):
607
+ n_mols = len(mask)
608
+ n_features = features.shape[1]
609
+
610
+ data = np.zeros(shape=(n_mols, n_features))
611
+ data.fill(value)
612
+ data[~mask] = features
613
+ return data
614
+
615
+
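`fill` re-expands a feature matrix computed on the sanitized molecules back to the original molecule count, padding skipped rows with NaN. A small worked example (the arrays are made up; note the mask passed in marks the *skipped* rows, i.e. `~clean_mol_mask`):

```python
import numpy as np

def fill(features, mask, value=np.nan):
    # mask marks rows that were SKIPPED; features covers the remaining rows
    data = np.full((len(mask), features.shape[1]), value)
    data[~mask] = features
    return data

feats = np.array([[1.0, 2.0], [3.0, 4.0]])
skipped = np.array([False, True, False])  # middle molecule failed sanitization
out = fill(feats, skipped)
print(out)
```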
616
+ def create_descriptors(
617
+ smiles,
618
+ descriptors,
619
+ **ecfp_kwargs,
620
+ ):
621
+ """Generate molecular descriptors for multiple SMILES strings.
622
+ Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
623
+
624
+ Each SMILES is processed and sanitized using RDKit.
625
+ SMILES that cannot be sanitized are encoded with NaNs, and a corresponding boolean mask
626
+ is returned to indicate which inputs were successfully processed.
627
+
628
+ Args:
629
+ smiles (list[str]): List of SMILES strings for which to generate descriptors.
630
+ descriptors (list[str]): List of descriptor types to compute.
631
+ Supported values include:
632
+ ['mhfps', 'ecfps', 'tox', 'maccs', 'rdkit_descrs'].
633
+
634
+ Returns:
635
+ tuple[dict[str, np.ndarray], np.ndarray]:
636
+ - A dictionary mapping descriptor names to their computed arrays.
637
+ - A boolean mask of shape (len(smiles),) indicating which SMILES
638
+ were successfully sanitized and processed.
639
+ """
640
+ # Create cleaned RDKit mol objects
641
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
642
+ print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")
643
+
644
+ # Create fingerprints and descriptors
645
+ if "mhfps" in descriptors:
646
+ mhfps = create_mhfp_fps(mols, **ecfp_kwargs)
647
+ mhfps = fill(mhfps, ~clean_mol_mask)
648
+ print("Created MHFP fingerprints")
649
+
650
+ if "ecfps" in descriptors:
651
+ ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
652
+ ecfps = fill(ecfps, ~clean_mol_mask)
653
+ print("Created ECFP fingerprints")
654
+
655
+ if "tox" in descriptors:
656
+ tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
657
+ tox = create_tox_features(mols, tox_patterns)
658
+ tox = fill(tox, ~clean_mol_mask)
659
+ print("Created Tox features")
660
+
661
+ if "maccs" in descriptors:
662
+ maccs = create_maccs_keys(mols)
663
+ maccs = fill(maccs, ~clean_mol_mask)
664
+ print("Created MACCS keys")
665
+
666
+ if "rdkit_descrs" in descriptors:
667
+ rdkit_descrs = create_rdkit_descriptors(mols)
668
+ rdkit_descrs = fill(rdkit_descrs, ~clean_mol_mask)
669
+ print("Created RDKit descriptors")
670
+
671
+ # concatenate features
672
+ features = {}
673
+ for descr in descriptors:
674
+ features[descr] = vars()[descr]
675
+
676
+ return features, clean_mol_mask
677
+
678
+
679
+ def get_tox21_split(token, cvfold=None):
680
+ """Retrieve Tox21 splits from HuggingFace with respect to given cvfold."""
681
+ ds = load_dataset("ml-jku/tox21", token=token)
682
+
683
+ train_df = ds["train"].to_pandas()
684
+ val_df = ds["validation"].to_pandas()
685
+
686
+ if cvfold is None:
687
+ return {"train": train_df, "validation": val_df}
688
+
689
+ combined_df = pd.concat([train_df, val_df], ignore_index=True)
690
+ cvfold = float(cvfold)
691
+
692
+ # create new splits
694
+ train_df = combined_df[combined_df.CVfold != cvfold]
695
+ val_df = combined_df[combined_df.CVfold == cvfold]
696
+
697
+ # exclude train mols that occur in the validation split
698
+ val_inchikeys = set(val_df["inchikey"])
699
+ train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
700
+
701
+ return {
702
+ "train": train_df.reset_index(drop=True),
703
+ "validation": val_df.reset_index(drop=True),
704
+ }
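The fold logic of `get_tox21_split` is: pool train and validation, carve the requested `CVfold` out as the new validation set, and then drop any training molecule whose InChIKey also occurs in validation. A self-contained pandas sketch with a tiny made-up frame:

```python
import pandas as pd

# hypothetical frame standing in for the pooled HuggingFace splits
df = pd.DataFrame({
    "smiles": ["CCO", "CCN", "CCC", "CCO"],
    "inchikey": ["K1", "K2", "K3", "K1"],
    "CVfold": [0.0, 1.0, 2.0, 1.0],
})

cvfold = 1.0
train = df[df.CVfold != cvfold]
val = df[df.CVfold == cvfold]

# drop train molecules whose inchikey also appears in validation
train = train[~train["inchikey"].isin(set(val["inchikey"]))]
print(len(train), len(val))
```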
src/utils.py ADDED
@@ -0,0 +1,524 @@
+ ## These MolStandardizer classes are due to Paolo Tosco.
+ ## They were taken from the FS-Mol GitHub repository
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
+ ## standardizer.py)
+ ## and ensure that a fixed sequence of standardization operations is applied.
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
+
+ import os
+ import pickle
+ from typing import Any
+
+ import numpy as np
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ TOX_SMARTS_PATH = "data/tox_smarts.json"
+
+ TASKS = [
+     "NR-AR",
+     "NR-AR-LBD",
+     "NR-AhR",
+     "NR-Aromatase",
+     "NR-ER",
+     "NR-ER-LBD",
+     "NR-PPAR-gamma",
+     "SR-ARE",
+     "SR-ATAD5",
+     "SR-HSE",
+     "SR-MMP",
+     "SR-p53",
+ ]
+
+ # Indices of the 200 RDKit descriptors that are used: 0-16 and 25-207.
+ USED_200_DESCR = list(range(17)) + list(range(25, 208))
+
+
+ class Standardizer:
+     """
+     Simple wrapper class around the RDKit Standardizer.
+     """
+
+     DEFAULT_CANON_TAUT = False
+     DEFAULT_METAL_DISCONNECT = False
+     MAX_TAUTOMERS = 100
+     MAX_TRANSFORMS = 100
+     MAX_RESTARTS = 200
+     PREFER_ORGANIC = True
+
+     def __init__(
+         self,
+         metal_disconnect=None,
+         canon_taut=None,
+     ):
+         """
+         Constructor. All parameters are optional.
+         :param metal_disconnect: if True, metallorganic complexes are
+             disconnected
+         :param canon_taut: if True, molecules are converted to their
+             canonical tautomer
+         """
+         super().__init__()
+         if metal_disconnect is None:
+             metal_disconnect = self.DEFAULT_METAL_DISCONNECT
+         if canon_taut is None:
+             canon_taut = self.DEFAULT_CANON_TAUT
+         self._canon_taut = canon_taut
+         self._metal_disconnect = metal_disconnect
+         self._taut_enumerator = None
+         self._uncharger = None
+         self._lfrag_chooser = None
+         self._metal_disconnector = None
+         self._normalizer = None
+         self._reionizer = None
+         self._params = None
+
+     @property
+     def params(self):
+         """Return the MolStandardize CleanupParameters."""
+         if self._params is None:
+             self._params = rdMolStandardize.CleanupParameters()
+             self._params.maxTautomers = self.MAX_TAUTOMERS
+             self._params.maxTransforms = self.MAX_TRANSFORMS
+             self._params.maxRestarts = self.MAX_RESTARTS
+             self._params.preferOrganic = self.PREFER_ORGANIC
+             self._params.tautomerRemoveSp3Stereo = False
+         return self._params
+
+     @property
+     def canon_taut(self):
+         """Return whether tautomer canonicalization will be done."""
+         return self._canon_taut
+
+     @property
+     def metal_disconnect(self):
+         """Return whether metallorganic complexes will be disconnected."""
+         return self._metal_disconnect
+
+     @property
+     def taut_enumerator(self):
+         """Return the TautomerEnumerator object."""
+         if self._taut_enumerator is None:
+             self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+         return self._taut_enumerator
+
+     @property
+     def uncharger(self):
+         """Return the Uncharger object."""
+         if self._uncharger is None:
+             self._uncharger = rdMolStandardize.Uncharger()
+         return self._uncharger
+
+     @property
+     def lfrag_chooser(self):
+         """Return the LargestFragmentChooser object."""
+         if self._lfrag_chooser is None:
+             self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
+                 self.params.preferOrganic
+             )
+         return self._lfrag_chooser
+
+     @property
+     def metal_disconnector(self):
+         """Return the MetalDisconnector object."""
+         if self._metal_disconnector is None:
+             self._metal_disconnector = rdMolStandardize.MetalDisconnector()
+         return self._metal_disconnector
+
+     @property
+     def normalizer(self):
+         """Return the Normalizer object."""
+         if self._normalizer is None:
+             self._normalizer = rdMolStandardize.Normalizer(
+                 self.params.normalizationsFile, self.params.maxRestarts
+             )
+         return self._normalizer
+
+     @property
+     def reionizer(self):
+         """Return the Reionizer object."""
+         if self._reionizer is None:
+             self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
+         return self._reionizer
+
+     def charge_parent(self, mol_in):
+         """Sequentially apply a series of MolStandardize operations:
+         * MetalDisconnector
+         * Normalizer
+         * Reionizer
+         * LargestFragmentChooser
+         * Uncharger
+         The net result is that a desalted, normalized, neutral
+         molecule with implicit Hs is returned.
+         """
+         params = Chem.RemoveHsParameters()
+         params.removeAndTrackIsotopes = True
+         mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
+         if self._metal_disconnect:
+             mol_in = self.metal_disconnector.Disconnect(mol_in)
+         normalized = self.normalizer.normalize(mol_in)
+         Chem.SanitizeMol(normalized)
+         normalized = self.reionizer.reionize(normalized)
+         Chem.AssignStereochemistry(normalized)
+         normalized = self.lfrag_chooser.choose(normalized)
+         normalized = self.uncharger.uncharge(normalized)
+         # need this to reassess aromaticity on things like
+         # cyclopentadienyl, tropylium, azolium, etc.
+         Chem.SanitizeMol(normalized)
+         return Chem.RemoveHs(Chem.AddHs(normalized))
+
+     def standardize_mol(self, mol_in):
+         """
+         Standardize a single molecule.
+         :param mol_in: a Chem.Mol
+         :return: * a (standardized Chem.Mol, n_taut) tuple on success;
+                    n_taut is negative if tautomer enumeration was
+                    aborted due to reaching a limit
+                  * (None, error_msg) on failure
+         This calls self.charge_parent() and, if self._canon_taut
+         is True, runs tautomer canonicalization.
+         """
+         n_tautomers = 0
+         if isinstance(mol_in, Chem.Mol):
+             name = None
+             try:
+                 name = mol_in.GetProp("_Name")
+             except KeyError:
+                 pass
+             if not name:
+                 name = "NONAME"
+         else:
+             error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
+             return None, error
+         try:
+             mol_out = self.charge_parent(mol_in)
+         except Exception as e:
+             error = f"charge_parent FAILED: {str(e).strip()}"
+             return None, error
+         if self._canon_taut:
+             try:
+                 res = self.taut_enumerator.Enumerate(mol_out, False)
+             except TypeError:
+                 # we are still on the pre-2021 RDKit API
+                 res = self.taut_enumerator.Enumerate(mol_out)
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+             n_tautomers = len(res)
+             if hasattr(res, "status"):
+                 completed = (
+                     res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
+                 )
+             else:
+                 # we are still on the pre-2021 RDKit API
+                 completed = len(res) < 1000
+             if not completed:
+                 n_tautomers = -n_tautomers
+             try:
+                 mol_out = self.taut_enumerator.PickCanonical(res)
+             except AttributeError:
+                 # we are still on the pre-2021 RDKit API
+                 mol_out = max(
+                     [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
+                 )[1]
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+         mol_out.SetProp("_Name", name)
+         return mol_out, n_tautomers
+
+
+ class FeatureDictMixin:
+     """
+     Mixin that enables bidirectional handling of dict-based multi-feature inputs.
+     Allows selective removal of columns directly from the combined array.
+
+     Example input:
+         {
+             "ecfps": np.ndarray,
+             "tox": np.ndarray,
+         }
+     """
+
+     def __init__(self, feature_keys=None):
+         self.feature_keys = feature_keys
+         self._curr_keys = None
+         self._unused_data = None
+
+     def dict_to_array(self, input: dict[Any, np.ndarray]) -> np.ndarray:
+         """Parse the dict input and concatenate it into a single array."""
+         if not isinstance(input, dict):
+             raise TypeError("Input must be a dict {feature_type: np.ndarray, ...}")
+
+         self._unused_data = {}
+         remaining_input = {}
+         for key in list(input.keys()):
+             if key not in self.feature_keys:
+                 self._unused_data[key] = input[key]
+             else:
+                 remaining_input[key] = input[key]
+
+         curr_keys = []
+         output = []
+         for key in self.feature_keys:
+             array = remaining_input.pop(key)
+             if array.ndim != 2:
+                 raise ValueError(f"Feature '{key}' must be 2D, got shape {array.shape}")
+
+             curr_keys.extend([key] * array.shape[1])
+             output.append(array)
+
+         self._curr_keys = np.array(curr_keys)
+
+         return np.concatenate(output, axis=1)
+
+     def array_to_dict(self, input: np.ndarray) -> dict[Any, np.ndarray]:
+         """Reconstruct the dict from a concatenated array."""
+         if self._curr_keys is None:
+             raise ValueError("No feature mapping stored. Did you call dict_to_array()?")
+
+         output = {key: input[:, self._curr_keys == key] for key in self.feature_keys}
+         output.update(self._unused_data)
+
+         self._curr_keys = None
+         self._unused_data = None
+         return output
+
+
+ def load_pickle(path: str):
+     with open(path, "rb") as file:
+         content = pickle.load(file)
+     return content
+
+
+ def write_pickle(path: str, obj: object):
+     with open(path, "wb") as file:
+         pickle.dump(obj, file)
+
+
+ def create_dir(path, is_file=False):
+     """Create the parent directories if a path to a file is given, else create the given directory."""
+     to_create = os.path.dirname(path) if is_file else path
+     if not os.path.exists(to_create):
+         os.makedirs(to_create)
+
+
+ def normalize_config(config: dict):
+     """Normalize a JSON config recursively by applying a mapping."""
+     mapping = {"none": None, "true": True, "false": False}
+     new_config = {}
+     for key, val in config.items():
+         if isinstance(val, dict):
+             new_config[key] = normalize_config(val)
+         elif isinstance(val, (int, float, str)) and val in mapping:
+             new_config[key] = mapping[val]
+         else:
+             new_config[key] = val
+     return new_config
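The FeatureDictMixin round trip can be sketched standalone with NumPy: concatenate named 2-D feature blocks into one array while remembering which columns belong to which key, then split the array back into a dict. (Re-implemented here for illustration with made-up feature names; the real logic in src/utils.py also tracks unused keys.)

```python
import numpy as np

# Hypothetical feature blocks, all with the same number of rows.
features = {
    "ecfps": np.ones((4, 3)),
    "tox": np.zeros((4, 2)),
}
feature_keys = ["ecfps", "tox"]

# One key label per column of the combined array.
col_keys = np.array(
    [key for key in feature_keys for _ in range(features[key].shape[1])]
)
combined = np.concatenate([features[key] for key in feature_keys], axis=1)

# Columns can now be selected/dropped on `combined`, then mapped back.
restored = {key: combined[:, col_keys == key] for key in feature_keys}
```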
train.py ADDED
@@ -0,0 +1,132 @@
+ """
+ Script for fitting and saving any preprocessing assets, as well as the fitted RF model.
+ """
+
+ import os
+ import json
+ import random
+ import logging
+ import argparse
+
+ import joblib
+ import numpy as np
+ from datetime import datetime
+
+ from src.model import Tox21RFClassifier
+ from src.preprocess import FeaturePreprocessor
+ from src.utils import create_dir, normalize_config
+
+ parser = argparse.ArgumentParser(description="RF training script for the Tox21 dataset")
+
+ parser.add_argument(
+     "--config",
+     type=str,
+     default="config/config.json",
+ )
+
+
+ def main(config):
+     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+     # set up the logger
+     logger = logging.getLogger(__name__)
+     script_name = os.path.splitext(os.path.basename(__file__))[0]
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(asctime)s [%(levelname)s] %(message)s",
+         handlers=[
+             logging.FileHandler(
+                 os.path.join(
+                     config["log_folder"],
+                     f"{script_name}_{timestamp}.log",
+                 )
+             ),
+             logging.StreamHandler(),
+         ],
+     )
+
+     logger.info(f"Config: {config}")
+     model_config_repr = "\n".join(
+         [str(val) for val in config["model_config"].values()]
+     )
+     logger.info(f"Model config: \n{model_config_repr}")
+
+     # seeding
+     random.seed(config["seed"])
+     np.random.seed(config["seed"])
+
+     train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
+     val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
+
+     # filter out unsanitized molecules
+     train_is_clean = train_data["clean_mol_mask"]
+     val_is_clean = val_data["clean_mol_mask"]
+     train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
+     val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}
+
+     if config["merge_train_val"]:
+         data = {
+             descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
+             for descr in config["descriptors"]
+         }
+         labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
+     else:
+         data = {descr: train_data[descr] for descr in config["descriptors"]}
+         labels = train_data["labels"]
+
+     if config["ckpt_path"]:
+         logger.info(
+             f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
+         )
+     else:
+         logger.info("Fitted RandomForestClassifier will NOT be saved.")
+
+     model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"])
+
+     # set up the preprocessor
+     preprocessor = FeaturePreprocessor(
+         feature_selection_config=config["feature_selection"],
+         feature_quantilization_config=config["feature_quantilization"],
+         descriptors=config["descriptors"],
+         max_samples=config["max_samples"],
+         scaler=config["scaler"],
+     )
+     preprocessor.fit(data)
+
+     logger.info("Start training.")
+     for i, task in enumerate(model.tasks):
+         task_labels = labels[:, i]
+         label_mask = ~np.isnan(task_labels)
+         logger.info(f"Fit task {task} using {label_mask.sum()} samples")
+
+         task_data = {key: val[label_mask] for key, val in data.items()}
+         task_labels = task_labels[label_mask].astype(int)
+
+         task_data = preprocessor.transform(task_data)
+         model.fit(task, task_data, task_labels)
+         if config["debug"]:
+             break
+
+     logger.info("Finished training.")
+
+     if config["ckpt_path"]:
+         model.save(config["ckpt_path"])
+         logger.info(f"Saved model as: {config['ckpt_path']}")
+
+     if config["preprocessor_path"]:
+         state = preprocessor.get_state()
+         joblib.dump(state, config["preprocessor_path"])
+         logger.info(f"Saved preprocessor as: {config['preprocessor_path']}")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+
+     with open(args.config, "r") as f:
+         config = json.load(f)
+     config = normalize_config(config)
+
+     create_dir(config["log_folder"])
+
+     main(config)
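train.py passes the parsed JSON config through normalize_config before use, which recursively maps the strings "none"/"true"/"false" to Python None/True/False. A standalone sketch of that behavior (the function body mirrors src/utils.py; the raw config values below are made up for illustration):

```python
def normalize_config(config: dict) -> dict:
    """Recursively map JSON placeholder strings to Python singletons."""
    mapping = {"none": None, "true": True, "false": False}
    new_config = {}
    for key, val in config.items():
        if isinstance(val, dict):
            # recurse into nested sections such as model_config
            new_config[key] = normalize_config(val)
        elif isinstance(val, (int, float, str)) and val in mapping:
            new_config[key] = mapping[val]
        else:
            new_config[key] = val
    return new_config


raw = {"ckpt_path": "none", "debug": "false", "model_config": {"bootstrap": "true"}}
cfg = normalize_config(raw)
```

This keeps the on-disk config plain JSON (which has no `None` literal and is often written with quoted booleans) while the Python code can rely on real `None`/`bool` values, e.g. in `if config["ckpt_path"]:`.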
uv.lock ADDED
The diff for this file is too large to render. See raw diff