Spaces:
Sleeping
Sleeping
Added first version
Browse files- .example.env +1 -0
- .gitignore +6 -0
- Dockerfile +16 -0
- LICENSE +407 -0
- README.md +115 -12
- app.py +77 -0
- config/config.json +31 -0
- docs/proposed_lightgbm_framework.md +203 -0
- predict.py +125 -0
- requirements.txt +13 -0
- src/constants.py +16 -0
- src/features.py +125 -0
- src/lightgbm_trainer.py +235 -0
- src/model.py +52 -0
- src/preprocess.py +65 -0
- src/seed.py +10 -0
- src/stage_two.py +106 -0
- src/train_evaluate.py +116 -0
- train.py +119 -0
.example.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
TOKEN=example_token
|
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
predict copy.py
|
| 2 |
+
hp_search/logs/*
|
| 3 |
+
hp_search/models/*
|
| 4 |
+
__pycache__
|
| 5 |
+
.env
|
| 6 |
+
notes.txt
|
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# you will also find guides on how best to write your Dockerfile
|
| 3 |
+
|
| 4 |
+
FROM python:3.11.4
|
| 5 |
+
|
| 6 |
+
RUN useradd -m -u 1000 user
|
| 7 |
+
USER user
|
| 8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 14 |
+
|
| 15 |
+
COPY --chown=user . /app
|
| 16 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Attribution-NonCommercial 4.0 International
|
| 2 |
+
|
| 3 |
+
=======================================================================
|
| 4 |
+
|
| 5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
+
does not provide legal services or legal advice. Distribution of
|
| 7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
+
other relationship. Creative Commons makes its licenses and related
|
| 9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
+
warranties regarding its licenses, any material licensed under their
|
| 11 |
+
terms and conditions, or any related information. Creative Commons
|
| 12 |
+
disclaims all liability for damages resulting from their use to the
|
| 13 |
+
fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
+
conditions that creators and other rights holders may use to share
|
| 19 |
+
original works of authorship and other material subject to copyright
|
| 20 |
+
and certain other rights specified in the public license below. The
|
| 21 |
+
following considerations are for informational purposes only, are not
|
| 22 |
+
exhaustive, and do not form part of our licenses.
|
| 23 |
+
|
| 24 |
+
Considerations for licensors: Our public licenses are
|
| 25 |
+
intended for use by those authorized to give the public
|
| 26 |
+
permission to use material in ways otherwise restricted by
|
| 27 |
+
copyright and certain other rights. Our licenses are
|
| 28 |
+
irrevocable. Licensors should read and understand the terms
|
| 29 |
+
and conditions of the license they choose before applying it.
|
| 30 |
+
Licensors should also secure all rights necessary before
|
| 31 |
+
applying our licenses so that the public can reuse the
|
| 32 |
+
material as expected. Licensors should clearly mark any
|
| 33 |
+
material not subject to the license. This includes other CC-
|
| 34 |
+
licensed material, or material used under an exception or
|
| 35 |
+
limitation to copyright. More considerations for licensors:
|
| 36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
+
|
| 38 |
+
Considerations for the public: By using one of our public
|
| 39 |
+
licenses, a licensor grants the public permission to use the
|
| 40 |
+
licensed material under specified terms and conditions. If
|
| 41 |
+
the licensor's permission is not necessary for any reason--for
|
| 42 |
+
example, because of any applicable exception or limitation to
|
| 43 |
+
copyright--then that use is not regulated by the license. Our
|
| 44 |
+
licenses grant only permissions under copyright and certain
|
| 45 |
+
other rights that a licensor has authority to grant. Use of
|
| 46 |
+
the licensed material may still be restricted for other
|
| 47 |
+
reasons, including because others have copyright or other
|
| 48 |
+
rights in the material. A licensor may make special requests,
|
| 49 |
+
such as asking that all changes be marked or described.
|
| 50 |
+
Although not required by our licenses, you are encouraged to
|
| 51 |
+
respect those requests where reasonable. More considerations
|
| 52 |
+
for the public:
|
| 53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
+
|
| 55 |
+
=======================================================================
|
| 56 |
+
|
| 57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 58 |
+
License
|
| 59 |
+
|
| 60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 61 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 63 |
+
License"). To the extent this Public License may be interpreted as a
|
| 64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
| 65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
| 66 |
+
such rights in consideration of benefits the Licensor receives from
|
| 67 |
+
making the Licensed Material available under these terms and
|
| 68 |
+
conditions.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Section 1 -- Definitions.
|
| 72 |
+
|
| 73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 74 |
+
Rights that is derived from or based upon the Licensed Material
|
| 75 |
+
and in which the Licensed Material is translated, altered,
|
| 76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 77 |
+
permission under the Copyright and Similar Rights held by the
|
| 78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 79 |
+
Material is a musical work, performance, or sound recording,
|
| 80 |
+
Adapted Material is always produced where the Licensed Material is
|
| 81 |
+
synched in timed relation with a moving image.
|
| 82 |
+
|
| 83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 85 |
+
accordance with the terms and conditions of this Public License.
|
| 86 |
+
|
| 87 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 88 |
+
closely related to copyright including, without limitation,
|
| 89 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 90 |
+
Rights, without regard to how the rights are labeled or
|
| 91 |
+
categorized. For purposes of this Public License, the rights
|
| 92 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 93 |
+
Rights.
|
| 94 |
+
d. Effective Technological Measures means those measures that, in the
|
| 95 |
+
absence of proper authority, may not be circumvented under laws
|
| 96 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 97 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 98 |
+
agreements.
|
| 99 |
+
|
| 100 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 101 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 102 |
+
that applies to Your use of the Licensed Material.
|
| 103 |
+
|
| 104 |
+
f. Licensed Material means the artistic or literary work, database,
|
| 105 |
+
or other material to which the Licensor applied this Public
|
| 106 |
+
License.
|
| 107 |
+
|
| 108 |
+
g. Licensed Rights means the rights granted to You subject to the
|
| 109 |
+
terms and conditions of this Public License, which are limited to
|
| 110 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 111 |
+
Licensed Material and that the Licensor has authority to license.
|
| 112 |
+
|
| 113 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 114 |
+
under this Public License.
|
| 115 |
+
|
| 116 |
+
i. NonCommercial means not primarily intended for or directed towards
|
| 117 |
+
commercial advantage or monetary compensation. For purposes of
|
| 118 |
+
this Public License, the exchange of the Licensed Material for
|
| 119 |
+
other material subject to Copyright and Similar Rights by digital
|
| 120 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 121 |
+
no payment of monetary compensation in connection with the
|
| 122 |
+
exchange.
|
| 123 |
+
|
| 124 |
+
j. Share means to provide material to the public by any means or
|
| 125 |
+
process that requires permission under the Licensed Rights, such
|
| 126 |
+
as reproduction, public display, public performance, distribution,
|
| 127 |
+
dissemination, communication, or importation, and to make material
|
| 128 |
+
available to the public including in ways that members of the
|
| 129 |
+
public may access the material from a place and at a time
|
| 130 |
+
individually chosen by them.
|
| 131 |
+
|
| 132 |
+
k. Sui Generis Database Rights means rights other than copyright
|
| 133 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 134 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 135 |
+
as amended and/or succeeded, as well as other essentially
|
| 136 |
+
equivalent rights anywhere in the world.
|
| 137 |
+
|
| 138 |
+
l. You means the individual or entity exercising the Licensed Rights
|
| 139 |
+
under this Public License. Your has a corresponding meaning.
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
Section 2 -- Scope.
|
| 143 |
+
|
| 144 |
+
a. License grant.
|
| 145 |
+
|
| 146 |
+
1. Subject to the terms and conditions of this Public License,
|
| 147 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 148 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 149 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 150 |
+
|
| 151 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 152 |
+
in part, for NonCommercial purposes only; and
|
| 153 |
+
|
| 154 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 155 |
+
NonCommercial purposes only.
|
| 156 |
+
|
| 157 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 158 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 159 |
+
License does not apply, and You do not need to comply with
|
| 160 |
+
its terms and conditions.
|
| 161 |
+
|
| 162 |
+
3. Term. The term of this Public License is specified in Section
|
| 163 |
+
6(a).
|
| 164 |
+
|
| 165 |
+
4. Media and formats; technical modifications allowed. The
|
| 166 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 167 |
+
all media and formats whether now known or hereafter created,
|
| 168 |
+
and to make technical modifications necessary to do so. The
|
| 169 |
+
Licensor waives and/or agrees not to assert any right or
|
| 170 |
+
authority to forbid You from making technical modifications
|
| 171 |
+
necessary to exercise the Licensed Rights, including
|
| 172 |
+
technical modifications necessary to circumvent Effective
|
| 173 |
+
Technological Measures. For purposes of this Public License,
|
| 174 |
+
simply making modifications authorized by this Section 2(a)
|
| 175 |
+
(4) never produces Adapted Material.
|
| 176 |
+
|
| 177 |
+
5. Downstream recipients.
|
| 178 |
+
|
| 179 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 180 |
+
recipient of the Licensed Material automatically
|
| 181 |
+
receives an offer from the Licensor to exercise the
|
| 182 |
+
Licensed Rights under the terms and conditions of this
|
| 183 |
+
Public License.
|
| 184 |
+
|
| 185 |
+
b. No downstream restrictions. You may not offer or impose
|
| 186 |
+
any additional or different terms or conditions on, or
|
| 187 |
+
apply any Effective Technological Measures to, the
|
| 188 |
+
Licensed Material if doing so restricts exercise of the
|
| 189 |
+
Licensed Rights by any recipient of the Licensed
|
| 190 |
+
Material.
|
| 191 |
+
|
| 192 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 193 |
+
may be construed as permission to assert or imply that You
|
| 194 |
+
are, or that Your use of the Licensed Material is, connected
|
| 195 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 196 |
+
the Licensor or others designated to receive attribution as
|
| 197 |
+
provided in Section 3(a)(1)(A)(i).
|
| 198 |
+
|
| 199 |
+
b. Other rights.
|
| 200 |
+
|
| 201 |
+
1. Moral rights, such as the right of integrity, are not
|
| 202 |
+
licensed under this Public License, nor are publicity,
|
| 203 |
+
privacy, and/or other similar personality rights; however, to
|
| 204 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 205 |
+
assert any such rights held by the Licensor to the limited
|
| 206 |
+
extent necessary to allow You to exercise the Licensed
|
| 207 |
+
Rights, but not otherwise.
|
| 208 |
+
|
| 209 |
+
2. Patent and trademark rights are not licensed under this
|
| 210 |
+
Public License.
|
| 211 |
+
|
| 212 |
+
3. To the extent possible, the Licensor waives any right to
|
| 213 |
+
collect royalties from You for the exercise of the Licensed
|
| 214 |
+
Rights, whether directly or through a collecting society
|
| 215 |
+
under any voluntary or waivable statutory or compulsory
|
| 216 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 217 |
+
reserves any right to collect such royalties, including when
|
| 218 |
+
the Licensed Material is used other than for NonCommercial
|
| 219 |
+
purposes.
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
Section 3 -- License Conditions.
|
| 223 |
+
|
| 224 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 225 |
+
following conditions.
|
| 226 |
+
|
| 227 |
+
a. Attribution.
|
| 228 |
+
|
| 229 |
+
1. If You Share the Licensed Material (including in modified
|
| 230 |
+
form), You must:
|
| 231 |
+
|
| 232 |
+
a. retain the following if it is supplied by the Licensor
|
| 233 |
+
with the Licensed Material:
|
| 234 |
+
|
| 235 |
+
i. identification of the creator(s) of the Licensed
|
| 236 |
+
Material and any others designated to receive
|
| 237 |
+
attribution, in any reasonable manner requested by
|
| 238 |
+
the Licensor (including by pseudonym if
|
| 239 |
+
designated);
|
| 240 |
+
|
| 241 |
+
ii. a copyright notice;
|
| 242 |
+
|
| 243 |
+
iii. a notice that refers to this Public License;
|
| 244 |
+
|
| 245 |
+
iv. a notice that refers to the disclaimer of
|
| 246 |
+
warranties;
|
| 247 |
+
|
| 248 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 249 |
+
extent reasonably practicable;
|
| 250 |
+
|
| 251 |
+
b. indicate if You modified the Licensed Material and
|
| 252 |
+
retain an indication of any previous modifications; and
|
| 253 |
+
|
| 254 |
+
c. indicate the Licensed Material is licensed under this
|
| 255 |
+
Public License, and include the text of, or the URI or
|
| 256 |
+
hyperlink to, this Public License.
|
| 257 |
+
|
| 258 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 259 |
+
reasonable manner based on the medium, means, and context in
|
| 260 |
+
which You Share the Licensed Material. For example, it may be
|
| 261 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 262 |
+
hyperlink to a resource that includes the required
|
| 263 |
+
information.
|
| 264 |
+
|
| 265 |
+
3. If requested by the Licensor, You must remove any of the
|
| 266 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 267 |
+
reasonably practicable.
|
| 268 |
+
|
| 269 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
| 270 |
+
License You apply must not prevent recipients of the Adapted
|
| 271 |
+
Material from complying with this Public License.
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
Section 4 -- Sui Generis Database Rights.
|
| 275 |
+
|
| 276 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 277 |
+
apply to Your use of the Licensed Material:
|
| 278 |
+
|
| 279 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 280 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 281 |
+
portion of the contents of the database for NonCommercial purposes
|
| 282 |
+
only;
|
| 283 |
+
|
| 284 |
+
b. if You include all or a substantial portion of the database
|
| 285 |
+
contents in a database in which You have Sui Generis Database
|
| 286 |
+
Rights, then the database in which You have Sui Generis Database
|
| 287 |
+
Rights (but not its individual contents) is Adapted Material; and
|
| 288 |
+
|
| 289 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 290 |
+
all or a substantial portion of the contents of the database.
|
| 291 |
+
|
| 292 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 293 |
+
replace Your obligations under this Public License where the Licensed
|
| 294 |
+
Rights include other Copyright and Similar Rights.
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 298 |
+
|
| 299 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 300 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 301 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 302 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 303 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 304 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 305 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 306 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 307 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 308 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 309 |
+
|
| 310 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 311 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 312 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 313 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 314 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 315 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 316 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 317 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 318 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 319 |
+
|
| 320 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 321 |
+
above shall be interpreted in a manner that, to the extent
|
| 322 |
+
possible, most closely approximates an absolute disclaimer and
|
| 323 |
+
waiver of all liability.
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
Section 6 -- Term and Termination.
|
| 327 |
+
|
| 328 |
+
a. This Public License applies for the term of the Copyright and
|
| 329 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 330 |
+
this Public License, then Your rights under this Public License
|
| 331 |
+
terminate automatically.
|
| 332 |
+
|
| 333 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 334 |
+
Section 6(a), it reinstates:
|
| 335 |
+
|
| 336 |
+
1. automatically as of the date the violation is cured, provided
|
| 337 |
+
it is cured within 30 days of Your discovery of the
|
| 338 |
+
violation; or
|
| 339 |
+
|
| 340 |
+
2. upon express reinstatement by the Licensor.
|
| 341 |
+
|
| 342 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 343 |
+
right the Licensor may have to seek remedies for Your violations
|
| 344 |
+
of this Public License.
|
| 345 |
+
|
| 346 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 347 |
+
Licensed Material under separate terms or conditions or stop
|
| 348 |
+
distributing the Licensed Material at any time; however, doing so
|
| 349 |
+
will not terminate this Public License.
|
| 350 |
+
|
| 351 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 352 |
+
License.
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
Section 7 -- Other Terms and Conditions.
|
| 356 |
+
|
| 357 |
+
a. The Licensor shall not be bound by any additional or different
|
| 358 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 359 |
+
|
| 360 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 361 |
+
Licensed Material not stated herein are separate from and
|
| 362 |
+
independent of the terms and conditions of this Public License.
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
Section 8 -- Interpretation.
|
| 366 |
+
|
| 367 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 368 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 369 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 370 |
+
be made without permission under this Public License.
|
| 371 |
+
|
| 372 |
+
b. To the extent possible, if any provision of this Public License is
|
| 373 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 374 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 375 |
+
cannot be reformed, it shall be severed from this Public License
|
| 376 |
+
without affecting the enforceability of the remaining terms and
|
| 377 |
+
conditions.
|
| 378 |
+
|
| 379 |
+
c. No term or condition of this Public License will be waived and no
|
| 380 |
+
failure to comply consented to unless expressly agreed to by the
|
| 381 |
+
Licensor.
|
| 382 |
+
|
| 383 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 384 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 385 |
+
that apply to the Licensor or You, including from the legal
|
| 386 |
+
processes of any jurisdiction or authority.
|
| 387 |
+
|
| 388 |
+
=======================================================================
|
| 389 |
+
|
| 390 |
+
Creative Commons is not a party to its public
|
| 391 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 392 |
+
its public licenses to material it publishes and in those instances
|
| 393 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
| 394 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 395 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 396 |
+
material is shared under a Creative Commons public license or as
|
| 397 |
+
otherwise permitted by the Creative Commons policies published at
|
| 398 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 399 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 400 |
+
of Creative Commons without its prior written consent including,
|
| 401 |
+
without limitation, in connection with any unauthorized modifications
|
| 402 |
+
to any of its public licenses or any other arrangements,
|
| 403 |
+
understandings, or agreements concerning use of licensed material. For
|
| 404 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 405 |
+
public licenses.
|
| 406 |
+
|
| 407 |
+
Creative Commons may be contacted at creativecommons.org.
|
README.md
CHANGED
|
@@ -1,12 +1,115 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MultiTaskTox – LightGBM Fingerprint Classifier for Tox21
|
| 2 |
+
|
| 3 |
+
MultiTaskTox is a two-stage Gradient Boosting workflow purpose-built for the [Tox21](https://huggingface.co/datasets/ml-jku/tox21) benchmark. It ingests molecular SMILES strings, converts them into high-dimensional fingerprints (ECFP or MAP4), and trains a set of LightGBM classifiers that leverage cross-task signal to improve toxicity prediction across all 12 Tox21 targets.
|
| 4 |
+
|
| 5 |
+
## Why MultiTaskTox?
|
| 6 |
+
|
| 7 |
+
- **Deterministic preprocessing** – every SMILES string is standardized through RDKit before fingerprint generation, ensuring training and inference behave identically.
|
| 8 |
+
- **Optuna-tuned per-task boosters** – each toxicity endpoint receives its own LightGBM classifier, tuned directly on the provided train/validation splits.
|
| 9 |
+
- **Multitask enhancement** – stage two augments the fingerprint vector with the predictions of the other tasks, capturing label correlations without building a fully joint model.
|
| 10 |
+
- **Leaderboard-ready interface** – `train.py` produces checkpoints and metadata under `checkpoints/`, while `predict.py` exposes the required `predict(smiles_list)` signature.
|
| 11 |
+
|
| 12 |
+
## Installation
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
git clone https://huggingface.co/spaces/ml-jku/tox21_gin_classifier
|
| 16 |
+
cd tox21_gin_classifier
|
| 17 |
+
python -m venv .venv && source .venv/bin/activate
|
| 18 |
+
pip install --upgrade pip
|
| 19 |
+
pip install -r requirements.txt
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
The requirements include RDKit, LightGBM, Optuna, and the MAP4 fingerprint package so you can switch feature types via the config.
|
| 23 |
+
|
| 24 |
+
## Training
|
| 25 |
+
|
| 26 |
+
1. Create a `.env` file (all Hugging Face Spaces support secrets) with your dataset token:
|
| 27 |
+
```
|
| 28 |
+
TOKEN=hf_xxx
|
| 29 |
+
```
|
| 30 |
+
2. Adjust `config/config.json` if needed (fingerprint type, Optuna trial count, etc.).
|
| 31 |
+
3. Run:
|
| 32 |
+
```bash
|
| 33 |
+
python train.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### What `train.py` does
|
| 37 |
+
|
| 38 |
+
1. Loads the predefined `train` and `validation` splits from the Tox21 dataset.
|
| 39 |
+
2. Standardizes SMILES and builds fingerprints using `src/features.py`.
|
| 40 |
+
3. For each target:
|
| 41 |
+
- Runs Optuna to find the best LightGBM hyperparameters using the validation split as the evaluation set.
|
| 42 |
+
- Fits the classifier (`stage1`) and stores the model as `checkpoints/stage1/<target>.pkl`.
|
| 43 |
+
4. Generates prediction matrices for both splits.
|
| 44 |
+
5. If multitask mode is enabled (`config["multitask"]["enabled"]`), creates augmented features (fingerprint + other-task predictions) and trains stage-two boosters saved under `checkpoints/stage2/`.
|
| 45 |
+
6. Writes metrics (`metrics_stage1.json`, `metrics_stage2.json`) and a manifest (`training_manifest.json`) describing the experiment.
|
| 46 |
+
|
| 47 |
+
## Inference
|
| 48 |
+
|
| 49 |
+
`predict.py` exposes:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
from predict import predict
|
| 53 |
+
|
| 54 |
+
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
|
| 55 |
+
results = predict(smiles)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
The function:
|
| 59 |
+
1. Loads the training manifest to know which fingerprint type and checkpoints to use.
|
| 60 |
+
2. Standardizes and fingerprints the SMILES on the fly.
|
| 61 |
+
3. Runs stage-one LightGBM classifiers to obtain probabilistic predictions.
|
| 62 |
+
4. If stage-two models exist, augments the features with cross-task predictions and runs the multitask models.
|
| 63 |
+
5. Returns `{smiles: {target_name: probability}}` with values in `[0, 1]`. Invalid SMILES fall back to `0.5`.
|
| 64 |
+
|
| 65 |
+
## Configuration Overview (`config/config.json`)
|
| 66 |
+
|
| 67 |
+
```json
|
| 68 |
+
{
|
| 69 |
+
"seed": 42,
|
| 70 |
+
"dataset": {"name": "ml-jku/tox21"},
|
| 71 |
+
"features": {
|
| 72 |
+
"type": "ecfp",
|
| 73 |
+
"radius": 2,
|
| 74 |
+
"n_bits": 1024,
|
| 75 |
+
"map4_dim": 1024,
|
| 76 |
+
"cache_dir": "./checkpoints/cache"
|
| 77 |
+
},
|
| 78 |
+
"training": {
|
| 79 |
+
"optuna_trials": 40,
|
| 80 |
+
"boosting_rounds": 1500,
|
| 81 |
+
"early_stopping_rounds": 100,
|
| 82 |
+
"lightgbm_params": {
|
| 83 |
+
"objective": "binary",
|
| 84 |
+
"metric": "auc",
|
| 85 |
+
"verbosity": -1
|
| 86 |
+
}
|
| 87 |
+
},
|
| 88 |
+
"multitask": {"enabled": true},
|
| 89 |
+
"output": {"checkpoint_dir": "./checkpoints"}
|
| 90 |
+
}
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
- Switch `features.type` to `"map4"` to use MAP4 fingerprints (installed by default).
|
| 94 |
+
- Disable multitask behavior by setting `"multitask": {"enabled": false}`.
|
| 95 |
+
- Increase `optuna_trials` for a more exhaustive search if compute allows.
|
| 96 |
+
|
| 97 |
+
## Repository Layout
|
| 98 |
+
|
| 99 |
+
- `train.py` – orchestrates the full training workflow (feature generation, Optuna tuning, stage-one and stage-two models).
|
| 100 |
+
- `predict.py` – leaderboard-friendly inference function that loads the checkpoints generated by `train.py`.
|
| 101 |
+
- `src/preprocess.py` – dataset loading and SMILES standardization helpers.
|
| 102 |
+
- `src/features.py` – fingerprint computation with disk caching.
|
| 103 |
+
- `src/lightgbm_trainer.py` – LightGBM + Optuna utilities for stage-one training.
|
| 104 |
+
- `src/stage_two.py` – multitask feature augmentation and model training.
|
| 105 |
+
- `src/constants.py`, `src/seed.py` – shared utilities.
|
| 106 |
+
- `docs/proposed_lightgbm_framework.md` – detailed design notes for the workflow.
|
| 107 |
+
- `checkpoints/` – default output directory containing models, metrics, caches, and the training manifest used at inference time.
|
| 108 |
+
|
| 109 |
+
## Tips
|
| 110 |
+
|
| 111 |
+
- Training relies on the `TOKEN` environment variable to access the Tox21 dataset on Hugging Face. Locally you can omit it if the dataset is public for your account.
|
| 112 |
+
- MAP4 fingerprints are more expensive to compute; enable the cache directory to avoid recomputation across runs.
|
| 113 |
+
- Use the saved metrics files to compare stage-one vs. stage-two AUCs and to trace which configuration produced a set of checkpoints.
|
| 114 |
+
|
| 115 |
+
Happy modeling! If you extend MultiTaskTox (new fingerprints, alternative learners, etc.), keep the `predict(smiles)` contract intact so your Space remains leaderboard compatible.
|
app.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is the main entry point for the FastAPI application.
|
| 3 |
+
The app handles the request to predict toxicity for a list of SMILES strings.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ---------------------------------------------------------------------------------------
|
| 7 |
+
# Dependencies and global variable definition
|
| 8 |
+
import os
|
| 9 |
+
from typing import List, Dict, Optional
|
| 10 |
+
from fastapi import FastAPI, Header, HTTPException
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
from predict import predict as predict_func
|
| 14 |
+
|
| 15 |
+
API_KEY = os.getenv("API_KEY") # set via Space Secrets
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------------------
|
| 19 |
+
class Request(BaseModel):
    """Request body for /predict: a non-empty batch of SMILES strings."""

    # Between 1 and 1000 molecules per request (validated by pydantic).
    smiles: List[str] = Field(min_items=1, max_items=1000)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Response(BaseModel):
    """Response body for /predict.

    Attributes:
        predictions: Nested mapping ``{smiles: {target_name: probability}}``
            as produced by ``predict_func``.
        model_info: Free-form metadata describing the model that produced
            the scores.
    """

    predictions: dict
    # default_factory avoids declaring a shared mutable dict as the default;
    # each response instance gets its own empty dict.
    model_info: Dict[str, str] = Field(default_factory=dict)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
app = FastAPI(title="toxicity-api")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.get("/")
|
| 32 |
+
def root():
|
| 33 |
+
return {
|
| 34 |
+
"message": "Toxicity Prediction API",
|
| 35 |
+
"endpoints": {
|
| 36 |
+
"/metadata": "GET - API metadata and capabilities",
|
| 37 |
+
"/healthz": "GET - Health check",
|
| 38 |
+
"/predict": "POST - Predict toxicity for SMILES",
|
| 39 |
+
},
|
| 40 |
+
"usage": "Send POST to /predict with {'smiles': ['your_smiles_here']}",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@app.get("/metadata")
|
| 45 |
+
def metadata():
|
| 46 |
+
return {
|
| 47 |
+
"name": "Tox21 GIN classifier",
|
| 48 |
+
"version": "1.0.0",
|
| 49 |
+
"tox_endpoints": [
|
| 50 |
+
"NR-AR",
|
| 51 |
+
"NR-AR-LBD",
|
| 52 |
+
"NR-AhR",
|
| 53 |
+
"NR-Aromatase",
|
| 54 |
+
"NR-ER",
|
| 55 |
+
"NR-ER-LBD",
|
| 56 |
+
"NR-PPAR-gamma",
|
| 57 |
+
"SR-ARE",
|
| 58 |
+
"SR-ATAD5",
|
| 59 |
+
"SR-HSE",
|
| 60 |
+
"SR-MMP",
|
| 61 |
+
"SR-p53",
|
| 62 |
+
],
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@app.get("/healthz")
|
| 67 |
+
def healthz():
|
| 68 |
+
return {"ok": True}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@app.post("/predict", response_model=Response)
def predict(request: Request):
    """Score every submitted SMILES against all Tox21 targets."""
    scores = predict_func(request.smiles)
    payload = {
        "predictions": scores,
        "model_info": {"name": "MultiTaskTox", "version": "0.0.1"},
    }
    return payload
|
config/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seed": 42,
|
| 3 |
+
"dataset": {
|
| 4 |
+
"name": "ml-jku/tox21"
|
| 5 |
+
},
|
| 6 |
+
"features": {
|
| 7 |
+
"type": "ecfp",
|
| 8 |
+
"radius": 2,
|
| 9 |
+
"n_bits": 1024,
|
| 10 |
+
"use_counts": false,
|
| 11 |
+
"map4_dim": 1024,
|
| 12 |
+
"cache_dir": "./checkpoints/cache"
|
| 13 |
+
},
|
| 14 |
+
"training": {
|
| 15 |
+
"optuna_trials": 40,
|
| 16 |
+
"boosting_rounds": 1500,
|
| 17 |
+
"early_stopping_rounds": 100,
|
| 18 |
+
"lightgbm_params": {
|
| 19 |
+
"objective": "binary",
|
| 20 |
+
"metric": "auc",
|
| 21 |
+
"verbosity": -1
|
| 22 |
+
}
|
| 23 |
+
},
|
| 24 |
+
"multitask": {
|
| 25 |
+
"enabled": true,
|
| 26 |
+
"prediction_source": "oof"
|
| 27 |
+
},
|
| 28 |
+
"output": {
|
| 29 |
+
"checkpoint_dir": "./checkpoints"
|
| 30 |
+
}
|
| 31 |
+
}
|
docs/proposed_lightgbm_framework.md
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LightGBM-Based Multitask Workflow for Tox21
|
| 2 |
+
|
| 3 |
+
This document proposes a stepwise plan to replace the current GIN baseline (`train.py`, `predict.py`, `src/`) with a Gradient Boosting pipeline that remains compatible with the leaderboard I/O contract. Each phase can be validated independently before moving to the next, ensuring we have working training and inference artifacts at all times.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 0. Repository Integration Checklist
|
| 8 |
+
- **Entry-points stay the same.** `train.py` must continue to train from `config/config.json` and drop an inference-ready artifact into `checkpoints/`. `predict.py` must keep the `predict(smiles_list)` signature and return the nested `{smiles: {target: score}}` mapping.
|
| 9 |
+
- **New modules.** Introduce `src/features.py` (fingerprints & caching), `src/lightgbm_trainer.py` (shared utilities for training/evaluation), and `src/stage_two.py` (cross-task augmentation logic). Keep `src/preprocess.py` for SMILES standardization + RDKit `Mol` construction so inference stays aligned with training.
|
| 10 |
+
- **Dependencies.** Add `lightgbm`, `optuna`, `rdkit-pypi`, and optionally the `map4` package (or its reference implementation) to `requirements.txt`. Verify any native dependencies are supported by the Spaces environment.
|
| 11 |
+
- **Artifacts.** Store per-task boosters as `checkpoints/stage1_{task}.txt` and `checkpoints/stage2_{task}.txt` (LightGBM text dumps). Derived predictions (e.g., stage-1 OOF matrices) should live under `checkpoints/cache/` or `/tmp` during training, but inference must rely only on checkpoint files generated by `train.py`.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## 1. Phase 1 — Baseline LightGBM with Optuna
|
| 16 |
+
|
| 17 |
+
### 1.1 Data handling
|
| 18 |
+
1. Load the Hugging Face dataset inside `train.py` exactly as today (`load_dataset("ml-jku/tox21", token=TOKEN)`).
|
| 19 |
+
2. Keep the same per-split segmentation (train/validation/test) to remain comparable with the GIN baseline.
|
| 20 |
+
3. Convert SMILES strings to RDKit `Mol` objects using the existing cleaners in `src/preprocess.py`. For the baseline, we can featurize molecules with a minimal descriptor set (e.g., RDKit physicochemical descriptors) while fingerprints are being implemented.
|
| 21 |
+
|
| 22 |
+
### 1.2 Baseline features
|
| 23 |
+
Use easily-computed descriptors such as:
|
| 24 |
+
- Molecular weight, logP, TPSA, number of H-bond donors/acceptors, rotatable bonds, aromatic proportion, etc.
|
| 25 |
+
- Concatenate one-hot encodings for atom count bins (C, N, O, halogens).
|
| 26 |
+
This gives a quick tabular vector per SMILES while fingerprint work is in progress.
|
| 27 |
+
|
| 28 |
+
### 1.3 Training objective
|
| 29 |
+
- **Task granularity:** Train one LightGBM binary classifier per Tox21 task (12 total). Targets remain the provided binary toxicity labels.
|
| 30 |
+
- **Metric:** ROC-AUC per task, with macro-average for reporting (mirrors leaderboard metric).
|
| 31 |
+
- **Data split:** For each task, drop rows with missing labels and perform K-fold CV (e.g., 5 folds) inside Optuna to make best use of labeled data.
|
| 32 |
+
|
| 33 |
+
### 1.4 Optuna search space
|
| 34 |
+
Within `src/lightgbm_trainer.py`, expose an `objective(trial, task_name)` that:
|
| 35 |
+
1. Samples:
|
| 36 |
+
- `learning_rate ∈ [1e-3, 0.2]` (log scale)
|
| 37 |
+
- `num_leaves ∈ [16, 256]`
|
| 38 |
+
- `max_depth ∈ [-1, 12]`
|
| 39 |
+
- `min_data_in_leaf ∈ [10, 200]`
|
| 40 |
+
- `feature_fraction ∈ [0.5, 1.0]`
|
| 41 |
+
- `bagging_fraction ∈ [0.5, 1.0]` with `bagging_freq ∈ [1, 10]`
|
| 42 |
+
- `lambda_l1`, `lambda_l2` (10^-8 to 10^1)
|
| 43 |
+
2. Trains the LightGBM model on each CV split and averages ROC-AUC.
|
| 44 |
+
3. Returns the negative mean ROC-AUC so Optuna can minimize the objective.
|
| 45 |
+
|
| 46 |
+
Persist the best hyperparameters per task into the config (or a JSON artifact) so `predict.py` can instantiate the booster with exact values. When data volume is small, Optuna’s `Study` can share the same random seed for reproducibility (`src/seed.py` can be reused).
|
| 47 |
+
|
| 48 |
+
### 1.5 Deliverables for Phase 1
|
| 49 |
+
- Updated `train.py` calling into `src/lightgbm_trainer.train_single_task(task_name, features, labels, config)`.
|
| 50 |
+
- `checkpoints/stage1_{task}.txt` boosters (even though they are “stage 1”, they form the baseline deliverable).
|
| 51 |
+
- Validation report (per-task ROC-AUC) saved to `checkpoints/metrics_stage1.json`.
|
| 52 |
+
- `predict.py` loads each per-task LightGBM model, computes baseline descriptors on-the-fly, and returns predictions.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## 2. Phase 2 — Fingerprint-Based Representations
|
| 57 |
+
|
| 58 |
+
### 2.1 Feature computation
|
| 59 |
+
Implement `src/features.py` with methods:
|
| 60 |
+
- `compute_ecfp(mol, radius=2, n_bits=1024)` using `GetMorganFingerprintAsBitVect`.
|
| 61 |
+
- `compute_map4(mol)` via MAP4 codebase (counts hashed patterns). Because MAP4 is computationally heavier, cache features to disk (e.g., `cache/fingerprints_{split}.npz`).
|
| 62 |
+
- `fingerprint_pipeline(smiles_list, fingerprint_type)` that accepts sanitized SMILES, constructs `Mol` objects, and returns a dense `np.ndarray`.
|
| 63 |
+
|
| 64 |
+
### 2.2 Integration
|
| 65 |
+
- Update `train.py` to choose the fingerprint type from config (e.g., `config["features"]["type"] = "ecfp"`).
|
| 66 |
+
- Align `predict.py` to call the same fingerprint builder on incoming SMILES.
|
| 67 |
+
- Maintain metadata describing fingerprint dimensionality and type in a manifest (e.g., `checkpoints/features.json`) so inference knows how to parse the stored LightGBM feature order.
|
| 68 |
+
|
| 69 |
+
### 2.3 Training flow
|
| 70 |
+
Apart from the enriched features, Phase 2 reuses the Phase 1 training loop. If resource constraints exist, we can:
|
| 71 |
+
- Run Optuna once on a representative task (e.g., NR-AhR) and reuse its best hyperparameters for all tasks; or
|
| 72 |
+
- Run Optuna briefly per task (e.g., 30 trials) and share results.
|
| 73 |
+
|
| 74 |
+
### 2.4 Deliverables
|
| 75 |
+
- Fingerprint cache builders + unit tests (small set of SMILES).
|
| 76 |
+
- Configurable training/inference that toggles between baseline descriptors and fingerprint vectors.
|
| 77 |
+
- Updated metrics comparing descriptors vs. ECFP vs. MAP4.
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## 3. Phase 3 — Cross-Task Label Augmentation
|
| 82 |
+
|
| 83 |
+
### 3.1 Motivation
|
| 84 |
+
By incorporating predictions from other tasks, we expose LightGBM to shared toxicity patterns without building a fully joint model. This is especially valuable for underrepresented tasks where correlated labels provide additional signal.
|
| 85 |
+
|
| 86 |
+
### 3.2 Feature construction
|
| 87 |
+
Given `T = 12` tasks and fingerprint dimension `D`, the augmented features for task `k` are:
|
| 88 |
+
```
|
| 89 |
+
X_k = [fingerprint_vector (D dims), ŷ_1, …, ŷ_{k-1}, ŷ_{k+1}, …, ŷ_T]
|
| 90 |
+
```
|
| 91 |
+
where `ŷ_t` are the stage-1 predictions for task `t` on the same molecule. Use floats instead of hard labels to preserve uncertainty.
|
| 92 |
+
|
| 93 |
+
### 3.3 Implementation details
|
| 94 |
+
1. **Collect stage-1 predictions.**
|
| 95 |
+
- After Phase 2 training, run inference with each stage-1 model on every molecule in train/val/test splits.
|
| 96 |
+
- Store the `N × T` prediction matrix in `checkpoints/stage1_predictions_{split}.npz`.
|
| 97 |
+
2. **Align missing data.**
|
| 98 |
+
- If task `t` lacks a label for a molecule, mask it during stage-1 training but still compute predictions for other tasks so the feature matrix stays dense.
|
| 99 |
+
3. **Data leakage prevention.**
|
| 100 |
+
- During training, use out-of-fold predictions (OOF) for the stage-1 features so models do not see their own ground-truth labels through the augmented vector.
|
| 101 |
+
- Implementation: For each fold, train stage-1 LightGBM on K-1 folds, predict on the held-out fold, and concatenate predictions.
|
| 102 |
+
4. **Config surface.**
|
| 103 |
+
- `config["multitask"]["use_stage1_predictions"] = true/false`
|
| 104 |
+
- `config["multitask"]["prediction_source"] = "oof" | "full_train"` to switch between strict OOF features and simpler (but leakier) full-train predictions for debugging.
|
| 105 |
+
|
| 106 |
+
### 3.4 Training
|
| 107 |
+
Once augmented features are ready, rerun the single-task LightGBM training per target (`stage2`). Hyperparameter search can be narrower because fingerprints already provide a strong baseline; focus on `num_leaves`, `feature_fraction`, and regularization strength.
|
| 108 |
+
|
| 109 |
+
### 3.5 Deliverables
|
| 110 |
+
- Scripts that generate OOF prediction matrices.
|
| 111 |
+
- Updated `train.py` orchestration:
|
| 112 |
+
1. Train Stage 1 models.
|
| 113 |
+
2. Materialize cross-task prediction cache.
|
| 114 |
+
3. Train Stage 2 models from augmented features.
|
| 115 |
+
- Metrics comparing Stage 1 vs. Stage 2 per task.
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
## 4. Phase 4 — Two-Stage Training & Inference
|
| 120 |
+
|
| 121 |
+
### 4.1 Training orchestration
|
| 122 |
+
Pseudo-flow for `train.py`:
|
| 123 |
+
|
| 124 |
+
```python
|
| 125 |
+
def train(config):
|
| 126 |
+
ds = load_dataset(...)
|
| 127 |
+
mols = preprocess.standardize(ds["train"]["smiles"])
|
| 128 |
+
fp_cache = features.fingerprint_pipeline(mols, config["features"])
|
| 129 |
+
|
| 130 |
+
stage1 = StageOneTrainer(config)
|
| 131 |
+
stage1.train_all_tasks(fp_cache, labels, splits)
|
| 132 |
+
stage1.save_models("checkpoints/stage1_*.txt")
|
| 133 |
+
|
| 134 |
+
pred_cache = stage1.generate_predictions(fp_cache, splits, use_oof=True)
|
| 135 |
+
|
| 136 |
+
stage2 = StageTwoTrainer(config)
|
| 137 |
+
stage2.train_all_tasks(fp_cache, pred_cache, labels)
|
| 138 |
+
stage2.save_models("checkpoints/stage2_*.txt")
|
| 139 |
+
|
| 140 |
+
dump_metrics(stage1.metrics, stage2.metrics)
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
### 4.2 Inference pipeline (`predict.py`)
|
| 144 |
+
1. **Fingerprint computation:** identical to training (deterministic sanitization).
|
| 145 |
+
2. **Stage-1 pass:** Load every `stage1_{task}.txt`, predict on the incoming SMILES batch, and collect predictions.
|
| 146 |
+
3. **Stage-2 pass:** For each task `k`, build `[fingerprint, predicted_labels_except_k]` on-the-fly and evaluate the corresponding stage-2 booster.
|
| 147 |
+
4. **Output:** Return the stage-2 predictions for leaderboard submission. Optionally include stage-1 scores in the response if needed for debugging (but the official output should stick to stage-2 values).
|
| 148 |
+
|
| 149 |
+
### 4.3 Failure modes & mitigations
|
| 150 |
+
- **Unrecognized SMILES:** fall back to zeros or 0.5 predictions like the current baseline but log warnings so we can monitor failure rates.
|
| 151 |
+
- **Missing checkpoint:** raise an informative exception instructing users to rerun `train.py`.
|
| 152 |
+
- **Performance drift:** store SHA or timestamp metadata with checkpoints to trace which training configuration produced a given model.
|
| 153 |
+
|
| 154 |
+
---
|
| 155 |
+
|
| 156 |
+
## 5. Configuration & Experiment Tracking
|
| 157 |
+
Proposed structure for `config/config.json`:
|
| 158 |
+
|
| 159 |
+
```json
|
| 160 |
+
{
|
| 161 |
+
"seed": 42,
|
| 162 |
+
"features": {
|
| 163 |
+
"type": "ecfp",
|
| 164 |
+
"radius": 2,
|
| 165 |
+
"n_bits": 1024,
|
| 166 |
+
"use_counts": false
|
| 167 |
+
},
|
| 168 |
+
"training": {
|
| 169 |
+
"n_folds": 5,
|
| 170 |
+
"n_optuna_trials": 50,
|
| 171 |
+
"lightgbm_params": {
|
| 172 |
+
"objective": "binary",
|
| 173 |
+
"metric": "auc",
|
| 174 |
+
"verbosity": -1
|
| 175 |
+
}
|
| 176 |
+
},
|
| 177 |
+
"multitask": {
|
| 178 |
+
"enabled": true,
|
| 179 |
+
"use_stage1_predictions": true,
|
| 180 |
+
"prediction_source": "oof"
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
Track experiment results in `checkpoints/experiments.csv` with columns `[timestamp, fingerprint, stage, task, auc, params_hash]`.
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## 6. Testing & Validation
|
| 190 |
+
- **Unit tests:** Ensure fingerprint builders reproduce known vectors (compare with RDKit reference) and that cross-task feature assembly drops the correct task column.
|
| 191 |
+
- **Integration tests:** Small toy dataset (3 tasks, <50 samples) to run the full Stage1→Stage2 pipeline quickly. Assert shapes of caches and that inference matches training predictions.
|
| 192 |
+
- **Performance tracking:** Plot per-task ROC-AUC improvements by phase to confirm each enhancement adds value.
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
## 7. Suggested Implementation Milestones
|
| 197 |
+
1. **M1:** Skeleton LightGBM trainer + Optuna integration (Phase 1). ✓
|
| 198 |
+
2. **M2:** Fingerprint computation module with caching + updated training/inference (Phase 2).
|
| 199 |
+
3. **M3:** Stage-1 prediction cache + feature augmentation (Phase 3).
|
| 200 |
+
4. **M4:** End-to-end Stage1→Stage2 orchestration, packaging of checkpoints, and inference updates (Phase 4).
|
| 201 |
+
5. **M5:** Documentation + automated tests to guard against regressions.
|
| 202 |
+
|
| 203 |
+
This phased roadmap keeps the leaderboard interface intact while progressively increasing the modeling capacity from simple descriptors to multitask-enhanced fingerprints.
|
predict.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
import joblib
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from src.constants import TARGET_NAMES
|
| 12 |
+
from src.features import FingerprintFeaturizer
|
| 13 |
+
from src.seed import set_seed
|
| 14 |
+
|
| 15 |
+
BASE_PREDICTION = 0.5
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@lru_cache(maxsize=1)
|
| 19 |
+
def _load_manifest() -> Dict:
|
| 20 |
+
manifest_path = Path("./checkpoints/training_manifest.json")
|
| 21 |
+
if not manifest_path.exists():
|
| 22 |
+
raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
|
| 23 |
+
with manifest_path.open("r", encoding="utf-8") as f:
|
| 24 |
+
manifest = json.load(f)
|
| 25 |
+
return manifest
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@lru_cache(maxsize=2)
def _load_stage_models(stage: str):
    """Load every per-target booster for a stage ("stage1" or "stage2").

    Returns an empty dict when the manifest has no ``model_dir`` for the
    stage; targets without a checkpoint file on disk are simply omitted.
    """
    manifest = _load_manifest()
    model_dir = manifest.get(stage, {}).get("model_dir")
    if not model_dir:
        return {}

    root_dir = Path(model_dir)
    loaded = {}
    for target in manifest.get("target_names", TARGET_NAMES):
        candidate = root_dir / f"{target}.pkl"
        if candidate.exists():
            loaded[target] = joblib.load(candidate)
    return loaded
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Return predictions for the valid molecules from stage-1 models.

    Targets without a loaded model keep the BASE_PREDICTION fallback column.
    """
    models = _load_stage_models("stage1")
    n_rows = features.shape[0]
    if n_rows == 0:
        return np.zeros((0, len(target_names)), dtype=np.float32)

    out = np.full((n_rows, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for col, target in enumerate(target_names):
        model = models.get(target)
        if model is None:
            continue  # no checkpoint for this target: keep the fallback score
        best_iter = getattr(model, "best_iteration_", None)
        extra = {} if best_iter is None else {"num_iteration": best_iter}
        out[:, col] = model.predict_proba(features, **extra)[:, 1]
    return out
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    """Refine stage-1 scores with stage-2 boosters trained on augmented features.

    Each stage-2 model for target ``k`` consumes ``[fingerprints, stage-1
    scores of every other target]``. Falls back to the stage-1 column when a
    stage-2 model is missing, and to the stage-1 matrix entirely when no
    stage-2 models are available.
    """
    models = _load_stage_models("stage2")
    if not models:
        return stage1_preds

    n_rows = base_features.shape[0]
    out = np.full((n_rows, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for col, target in enumerate(target_names):
        model = models.get(target)
        if model is None:
            out[:, col] = stage1_preds[:, col]
            continue
        # Drop this target's own stage-1 column so the augmented layout
        # matches what the stage-2 model saw during training.
        peers = np.delete(stage1_preds, col, axis=1)
        augmented = np.concatenate([base_features, peers], axis=1)
        best_iter = getattr(model, "best_iteration_", None)
        extra = {} if best_iter is None else {"num_iteration": best_iter}
        out[:, col] = model.predict_proba(augmented, **extra)[:, 1]
    return out
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    Invalid SMILES (those the standardizer rejects) receive the flat
    BASE_PREDICTION score for every target; valid molecules receive
    stage-2 scores (or stage-1 scores when no stage-2 models exist).

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    set_seed(0)
    manifest = _load_manifest()
    target_names = manifest.get("target_names", TARGET_NAMES)
    feature_config = manifest.get("feature_config", {"type": "ecfp"})

    # Featurize with the same configuration used during training so the
    # feature layout matches the stored boosters.
    featurizer = FingerprintFeaturizer(feature_config)
    batch, features = featurizer.featurize_smiles(smiles_list)

    # `features` contains only the valid molecules; batch.mask tells us
    # which entries of smiles_list they correspond to.
    stage1_preds = _compute_stage1_predictions(features, target_names)
    stage2_preds = _compute_stage2_predictions(features, stage1_preds, target_names)

    predictions: Dict[str, Dict[str, float]] = {}
    # Walking pointer into the prediction rows (valid molecules only);
    # must advance in lockstep with the True entries of batch.mask.
    valid_idx = 0

    for original_smiles, is_valid in zip(smiles_list, batch.mask):
        if not is_valid:
            # Unparseable SMILES: fall back to a flat 0.5 for every target.
            predictions[original_smiles] = {target: BASE_PREDICTION for target in target_names}
            continue

        row_preds = stage2_preds[valid_idx] if stage2_preds.size else np.full(len(target_names), BASE_PREDICTION)
        predictions[original_smiles] = {target: float(score) for target, score in zip(target_names, row_preds)}
        valid_idx += 1

    return predictions
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
numpy==1.26.2
|
| 4 |
+
python-dotenv
|
| 5 |
+
pandas==2.2.2
|
| 6 |
+
scikit-learn==1.7.1
|
| 7 |
+
pydantic
|
| 8 |
+
rdkit-pypi
|
| 9 |
+
datasets
|
| 10 |
+
lightgbm
|
| 11 |
+
optuna
|
| 12 |
+
joblib
|
| 13 |
+
map4
|
src/constants.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tox21 assay targets in the canonical order used throughout training and
# inference (prediction matrices and checkpoint lookups assume this order).
TARGET_NAMES = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# Dataframe column holding standardized (canonical) SMILES strings.
CANONICAL_SMILES_COLUMN = "canonical_smiles"
|
src/features.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, Sequence
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from rdkit import DataStructs
|
| 9 |
+
from rdkit.Chem import AllChem
|
| 10 |
+
|
| 11 |
+
from .constants import CANONICAL_SMILES_COLUMN
|
| 12 |
+
from .preprocess import MoleculeBatch, filter_dataframe_by_mask, standardize_smiles
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from map4 import MAP4Calculator # type: ignore
|
| 16 |
+
except Exception: # pragma: no cover - optional dependency
|
| 17 |
+
MAP4Calculator = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FingerprintFeaturizer:
    """Compute molecular fingerprints with optional caching.

    Supports ECFP (Morgan) and MAP4 fingerprints, selected via
    ``feature_config["type"]``. When ``cache_dir`` is configured,
    per-split fingerprint matrices are persisted as ``.npz`` files.
    """

    def __init__(self, feature_config: Dict):
        self.config = feature_config
        # Normalize the type so config values like "ECFP" still match.
        self.fingerprint_type = feature_config.get("type", "ecfp").lower()
        self.radius = feature_config.get("radius", 2)
        self.n_bits = feature_config.get("n_bits", 1024)
        self.map4_dim = feature_config.get("map4_dim", 1024)
        self.use_counts = feature_config.get("use_counts", False)
        cache_dir = feature_config.get("cache_dir")
        self.cache_dir = Path(cache_dir) if cache_dir else None
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def featurize_dataframe(self, df: pd.DataFrame, split_name: str):
        """Featurize a split dataframe, reusing the on-disk cache if present.

        Returns (clean_df, features): the dataframe filtered to molecules
        that standardized successfully, and their fingerprint matrix.
        NOTE(review): assumes `df` has a "smiles" column — confirm at callers.
        """
        cache_payload = self._load_cache(split_name)
        if cache_payload is not None:
            # Cache hit: rebuild the filtered dataframe from the stored
            # validity mask and canonical SMILES, skip fingerprinting.
            mask = cache_payload["mask"]
            canonical_smiles = cache_payload["canonical_smiles"].tolist()
            features = cache_payload["features"]
            clean_df = filter_dataframe_by_mask(df, mask, canonical_smiles)
            return clean_df, features

        batch = standardize_smiles(df["smiles"].tolist())
        clean_df = filter_dataframe_by_mask(df, batch.mask, batch.canonical_smiles)
        features = self._compute_fingerprints(batch.mols)

        self._write_cache(split_name, batch.mask, batch.canonical_smiles, features)
        return clean_df, features

    def featurize_smiles(self, smiles: Sequence[str]) -> tuple[MoleculeBatch, np.ndarray]:
        """Featurize raw SMILES without caching (used at inference time)."""
        batch = standardize_smiles(smiles)
        features = self._compute_fingerprints(batch.mols)
        return batch, features

    def _cache_path(self, split_name: str) -> Path | None:
        """Cache file path for a split, or None when caching is disabled."""
        if self.cache_dir is None:
            return None
        # Fingerprint type is part of the name so ecfp/map4 caches coexist.
        return self.cache_dir / f"{split_name}_{self.fingerprint_type}.npz"

    def _load_cache(self, split_name: str):
        """Return the cached npz payload for a split, or None if absent."""
        cache_path = self._cache_path(split_name)
        if cache_path is None or not cache_path.exists():
            return None
        # allow_pickle is needed for the object-dtype canonical_smiles array.
        return np.load(cache_path, allow_pickle=True)

    def _write_cache(self, split_name: str, mask, canonical_smiles, features):
        """Persist mask, canonical SMILES, and features for a split (no-op
        when caching is disabled)."""
        cache_path = self._cache_path(split_name)
        if cache_path is None:
            return
        np.savez(
            cache_path,
            mask=mask,
            canonical_smiles=np.array(canonical_smiles, dtype=object),
            features=features,
        )

    def _compute_fingerprints(self, mols):
        """Dispatch to the configured fingerprint implementation.

        Raises:
            ValueError: for an unsupported fingerprint type.
        """
        if not mols:
            # Keep the column count correct even for an empty batch so
            # downstream concatenation/prediction code does not break.
            dim = self._fingerprint_dimension()
            return np.zeros((0, dim), dtype=np.float32)

        if self.fingerprint_type == "ecfp":
            return self._compute_ecfp(mols)
        if self.fingerprint_type == "map4":
            return self._compute_map4(mols)
        raise ValueError(f"Unsupported fingerprint type: {self.fingerprint_type}")

    def _fingerprint_dimension(self) -> int:
        """Width of the fingerprint vector for the configured type."""
        if self.fingerprint_type == "map4":
            return self.map4_dim
        return self.n_bits

    def _compute_ecfp(self, mols):
        """Compute Morgan (ECFP) fingerprints, binary or hashed counts."""
        fingerprints = np.zeros((len(mols), self.n_bits), dtype=np.float32)
        for idx, mol in enumerate(mols):
            if self.use_counts:
                # Unfolded count fingerprint; fold hashed bit ids into
                # n_bits columns by modulo (collisions accumulate counts).
                fp = AllChem.GetMorganFingerprint(mol, self.radius)
                arr = np.zeros(self.n_bits, dtype=np.float32)
                for bit, value in fp.GetNonzeroElements().items():
                    arr[bit % self.n_bits] += value
            else:
                bitvect = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    self.radius,
                    nBits=self.n_bits,
                )
                arr = np.zeros(self.n_bits, dtype=np.float32)
                DataStructs.ConvertToNumpyArray(bitvect, arr)
            fingerprints[idx] = arr
        return fingerprints

    def _compute_map4(self, mols):
        """Compute MAP4 fingerprints; requires the optional `map4` package.

        Raises:
            ImportError: when the `map4` package is not importable.
        """
        if MAP4Calculator is None:
            raise ImportError(
                "MAP4 fingerprint requested but the `map4` package is not installed. "
                "Install it via `pip install map4` or switch features.type to 'ecfp'."
            )

        calc = MAP4Calculator(dimensions=self.map4_dim)
        fingerprints = np.zeros((len(mols), self.map4_dim), dtype=np.float32)
        for idx, mol in enumerate(mols):
            vec = np.array(calc.calculate(mol), dtype=np.float32)
            fingerprints[idx] = vec
        return fingerprints
|
src/lightgbm_trainer.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Optional, Sequence
|
| 7 |
+
|
| 8 |
+
import joblib
|
| 9 |
+
import lightgbm as lgb
|
| 10 |
+
import numpy as np
|
| 11 |
+
import optuna
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from sklearn.metrics import roc_auc_score
|
| 14 |
+
|
| 15 |
+
from .constants import TARGET_NAMES
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class TaskTrainingOutput:
    """Result bundle for one trained single-task LightGBM classifier."""
    # Fitted classifier for this target.
    model: lgb.LGBMClassifier
    # ROC-AUC achieved on the validation split.
    val_auc: float
    # Boosting round chosen (e.g. by early stopping) for prediction.
    best_iteration: int
    # Hyperparameters of the final fit (Optuna best trial + fixed keys).
    best_params: Dict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _sample_hyperparams(trial: optuna.Trial, base_params: Dict) -> Dict:
|
| 27 |
+
params = dict(base_params)
|
| 28 |
+
params.update(
|
| 29 |
+
{
|
| 30 |
+
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
|
| 31 |
+
"num_leaves": trial.suggest_int("num_leaves", 16, 256, log=True),
|
| 32 |
+
"max_depth": trial.suggest_int("max_depth", -1, 12),
|
| 33 |
+
"min_child_samples": trial.suggest_int("min_child_samples", 10, 200),
|
| 34 |
+
"feature_fraction": trial.suggest_float("feature_fraction", 0.5, 1.0),
|
| 35 |
+
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.5, 1.0),
|
| 36 |
+
"bagging_freq": trial.suggest_int("bagging_freq", 1, 10),
|
| 37 |
+
"reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
|
| 38 |
+
"reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
|
| 39 |
+
}
|
| 40 |
+
)
|
| 41 |
+
params.setdefault("objective", "binary")
|
| 42 |
+
params.setdefault("metric", "auc")
|
| 43 |
+
params.setdefault("verbosity", -1)
|
| 44 |
+
params.setdefault("boosting_type", "gbdt")
|
| 45 |
+
params.setdefault("n_jobs", -1)
|
| 46 |
+
return params
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def train_lightgbm_task(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    base_params: Dict,
    boosting_rounds: int,
    early_stopping_rounds: int,
    n_trials: int,
    seed: int,
) -> Optional[TaskTrainingOutput]:
    """Tune and fit a single-task LightGBM classifier with Optuna.

    Runs an Optuna search over the space defined in `_sample_hyperparams`,
    refits the best configuration, and reports its validation ROC-AUC.

    Returns None when either split contains fewer than two classes (AUC and
    binary training are undefined in that case).
    """
    if len(np.unique(y_train)) < 2 or len(np.unique(y_val)) < 2:
        return None

    def _fit(params: Dict) -> lgb.LGBMClassifier:
        # Fit with early stopping on the validation AUC.
        # NOTE: `verbose` was removed from LGBMClassifier.fit in LightGBM 4.0
        # (passing it raises TypeError); silencing is done via the callback.
        model = lgb.LGBMClassifier(**params)
        model.fit(
            X_train,
            y_train,
            eval_set=[(X_val, y_val)],
            eval_metric="auc",
            callbacks=[
                lgb.early_stopping(
                    early_stopping_rounds,
                    first_metric_only=True,
                    verbose=False,
                )
            ],
        )
        return model

    def _best_iter(model: lgb.LGBMClassifier) -> int:
        # best_iteration_ is None when early stopping never triggered;
        # fall back to the full boosting budget so int() below cannot fail.
        return getattr(model, "best_iteration_", None) or boosting_rounds

    def objective(trial: optuna.Trial) -> float:
        params = _sample_hyperparams(trial, base_params)
        params["n_estimators"] = boosting_rounds
        params["random_state"] = seed
        model = _fit(params)
        preds = model.predict_proba(X_val, num_iteration=_best_iter(model))[:, 1]
        return float(roc_auc_score(y_val, preds))

    # Seed the sampler so repeated runs explore the same trial sequence.
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=n_trials, show_progress_bar=False)

    # Re-materialize the winning configuration (suggest_* on a finished
    # best_trial replays its stored values) and refit the final model.
    best_params = _sample_hyperparams(study.best_trial, base_params)
    best_params["n_estimators"] = boosting_rounds
    best_params["random_state"] = seed

    final_model = _fit(best_params)
    best_iteration = _best_iter(final_model)
    val_preds = final_model.predict_proba(X_val, num_iteration=best_iteration)[:, 1]
    val_auc = roc_auc_score(y_val, val_preds)

    return TaskTrainingOutput(
        model=final_model,
        val_auc=float(val_auc),
        best_iteration=int(best_iteration),
        best_params=best_params,
    )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def save_stage_metrics(metrics: Dict, path: Path):
    """Serialize a stage's metrics dict as pretty-printed JSON at *path*.

    Missing parent directories are created on demand.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def train_stage_one_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train one LightGBM model per target on the raw fingerprint features.

    For each target with usable labels, trains an Optuna-tuned model via
    train_lightgbm_task, saves it under ``checkpoint_dir/stage1/<task>.pkl``,
    and collects full-dataset probability predictions (used later as
    auxiliary features for stage 2).

    Returns a dict with:
      - "train_full": (n_train, n_tasks) float32 prediction matrix
        (0.5 filler for skipped tasks),
      - "val_full": same for the validation split, or None when absent,
      - "metrics": per-task status / validation AUC dict.
    """
    stage_dir = checkpoint_dir / "stage1"
    stage_dir.mkdir(parents=True, exist_ok=True)

    # Training / search knobs with conservative defaults.
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    boosting_rounds = training_cfg.get("boosting_rounds", 1500)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    n_train = len(train_df)
    n_tasks = len(target_names)

    # 0.5 ("no signal") filler keeps the matrices well-defined for tasks
    # that get skipped below.
    train_preds = np.full((n_train, n_tasks), 0.5, dtype=np.float32)
    val_preds = (
        np.full((len(val_df), n_tasks), 0.5, dtype=np.float32)
        if val_df is not None and val_features is not None
        else None
    )

    metrics: Dict[str, Dict] = {}
    params_dump: Dict[str, Dict] = {}

    for task_idx, task_name in enumerate(target_names):
        # Labels are sparse per task: restrict to rows labeled for this task.
        train_mask = train_df[task_name].notna().values
        if val_df is None or val_features is None:
            metrics[task_name] = {"status": "skipped", "reason": "missing validation split"}
            continue

        val_mask = val_df[task_name].notna().values
        if train_mask.sum() < 2 or val_mask.sum() < 2:
            metrics[task_name] = {"status": "skipped", "reason": "insufficient labeled data"}
            continue

        X_train_task = train_features[train_mask]
        y_train_task = train_df.loc[train_mask, task_name].astype(float).values
        X_val_task = val_features[val_mask]
        y_val_task = val_df.loc[val_mask, task_name].astype(float).values

        # ROC-AUC is undefined when either split has a single class.
        if len(np.unique(y_train_task)) < 2 or len(np.unique(y_val_task)) < 2:
            metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
            continue

        task_result = train_lightgbm_task(
            X_train_task,
            y_train_task,
            X_val_task,
            y_val_task,
            base_params=base_params,
            boosting_rounds=boosting_rounds,
            early_stopping_rounds=early_stopping,
            n_trials=n_trials,
            seed=seed,
        )

        if task_result is None:
            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
            continue

        model = task_result.model
        best_iter = task_result.best_iteration

        model_path = stage_dir / f"{task_name}.pkl"
        joblib.dump(model, model_path)

        params_dump[task_name] = {
            **task_result.best_params,
            "best_iteration": best_iter,
            "val_auc": task_result.val_auc,
        }

        # Predict on ALL rows (not only labeled ones): these probabilities
        # become input features for the stage-2 models.
        full_train_preds = model.predict_proba(
            train_features,
            num_iteration=best_iter,
        )[:, 1]
        train_preds[:, task_idx] = full_train_preds.astype(np.float32)

        if val_preds is not None:
            full_val_preds = model.predict_proba(
                val_features,
                num_iteration=best_iter,
            )[:, 1]
            val_preds[:, task_idx] = full_val_preds.astype(np.float32)

        metrics[task_name] = {
            "val_auc": task_result.val_auc,
            "n_train_samples": int(train_mask.sum()),
            "n_val_samples": int(val_mask.sum()),
        }

    # Persist metrics and the winning hyperparameters alongside the models.
    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage1.json")
    params_path = checkpoint_dir / "stage1_params.json"
    with params_path.open("w", encoding="utf-8") as f:
        json.dump(params_dump, f, indent=2)

    return {
        "train_full": train_preds,
        "val_full": val_preds,
        "metrics": metrics,
    }
|
src/model.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GIN(torch.nn.Module):
    """Graph Isomorphism Network for multi-label graph classification.

    Stacks ``num_layers`` GINConv blocks (each wrapping a two-layer MLP with
    BatchNorm), pools node embeddings into a graph embedding ("add" or
    "mean"), and applies a linear classifier head.
    """

    def __init__(self, num_features, num_classes, dropout, hidden_dim=128, num_layers=5, add_or_mean="add"):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        # Pooling mode: "add" (sum over nodes) or "mean".
        self.add_or_mean = add_or_mean
        # Dropout probability applied after each conv layer in forward().
        self.dropout = dropout

        self.conv_layers = nn.ModuleList()

        # input features → hidden_dim
        mlp = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )
        self.conv_layers.append(GINConv(mlp, train_eps=True))

        # hidden GIN layers
        for _ in range(num_layers - 1):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim)
            )
            self.conv_layers.append(GINConv(mlp, train_eps=True))

        # Final classifier (after pooling)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph logits of shape [num_graphs, num_classes]."""
        for conv in self.conv_layers:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Pool to get graph-level representation
        if self.add_or_mean == "mean":
            x = global_mean_pool(x, batch)
        elif self.add_or_mean == "add":
            x = global_add_pool(x, batch)

        # NOTE(review): this final dropout uses a fixed p=0.5 rather than
        # self.dropout — confirm this is intentional.
        x = F.dropout(x, p=0.5, training=self.training)
        return self.fc(x)
|
src/preprocess.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Sequence
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
from rdkit.Chem.MolStandardize import rdMolStandardize
|
| 11 |
+
|
| 12 |
+
from .constants import CANONICAL_SMILES_COLUMN
|
| 13 |
+
|
| 14 |
+
@dataclass
class MoleculeBatch:
    """Result of standardizing a batch of SMILES strings."""
    # Successfully standardized RDKit molecules (failed inputs are dropped).
    mols: List[Chem.Mol]
    # Boolean mask over the ORIGINAL input: True where standardization succeeded.
    mask: np.ndarray
    # Canonical SMILES, aligned 1:1 with `mols` (NOT with the original input).
    canonical_smiles: List[str]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_tox21_dataset(token: str | None, dataset_name: str) -> Dict[str, pd.DataFrame]:
    """Load dataset splits from Hugging Face into pandas DataFrames."""
    raw = load_dataset(dataset_name, token=token)
    # One DataFrame per split (e.g. "train", "validation", "test").
    return {split_name: split.to_pandas() for split_name, split in raw.items()}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def standardize_smiles(smiles: Sequence[str]) -> MoleculeBatch:
    """Standardize SMILES strings and return RDKit molecules with canonical SMILES."""
    enumerator = rdMolStandardize.TautomerEnumerator()
    cleanup = rdMolStandardize.CleanupParameters()

    kept_mols: List[Chem.Mol] = []
    kept_smiles: List[str] = []
    ok = np.zeros(len(smiles), dtype=bool)

    for position, raw in enumerate(smiles):
        try:
            candidate = Chem.MolFromSmiles(raw)
            if candidate is None:
                continue

            # Cleanup + canonical tautomer, then round-trip through SMILES
            # so the returned molecule matches its canonical string exactly.
            candidate = rdMolStandardize.Cleanup(candidate, cleanup)
            candidate = enumerator.Canonicalize(candidate)
            canonical = Chem.MolToSmiles(candidate)
            candidate = Chem.MolFromSmiles(canonical)
            if candidate is None:
                continue
        except Exception:
            # Best-effort: any RDKit failure simply drops this molecule.
            continue

        kept_mols.append(candidate)
        kept_smiles.append(canonical)
        ok[position] = True

    return MoleculeBatch(mols=kept_mols, mask=ok, canonical_smiles=kept_smiles)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def filter_dataframe_by_mask(df: pd.DataFrame, mask: np.ndarray, canonical_smiles: Sequence[str]) -> pd.DataFrame:
    """Apply mask to dataframe and append canonical SMILES column."""
    kept = df.loc[mask].copy()
    kept.reset_index(drop=True, inplace=True)
    # canonical_smiles must be aligned with the True entries of `mask`.
    kept[CANONICAL_SMILES_COLUMN] = canonical_smiles
    return kept
|
src/seed.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def set_seed(seed: int = 42):
    """Seed Python's RNG, NumPy's RNG, and hash randomization for reproducibility."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
|
src/stage_two.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, Optional, Sequence
|
| 5 |
+
|
| 6 |
+
import joblib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from sklearn.metrics import roc_auc_score
|
| 10 |
+
|
| 11 |
+
from .constants import TARGET_NAMES
|
| 12 |
+
from .lightgbm_trainer import save_stage_metrics, train_lightgbm_task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _build_augmented_matrix(base_features: np.ndarray, prediction_matrix: np.ndarray, target_idx: int) -> np.ndarray:
|
| 16 |
+
mask = np.ones(prediction_matrix.shape[1], dtype=bool)
|
| 17 |
+
mask[target_idx] = False
|
| 18 |
+
return np.concatenate([base_features, prediction_matrix[:, mask]], axis=1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def train_stage_two_models(
    train_features: np.ndarray,
    val_features: Optional[np.ndarray],
    train_df: pd.DataFrame,
    val_df: Optional[pd.DataFrame],
    config: Dict,
    checkpoint_dir: Path,
    stage1_train_preds: np.ndarray,
    stage1_val_preds: Optional[np.ndarray],
    target_names: Sequence[str] = TARGET_NAMES,
) -> Dict:
    """Train per-task LightGBM models on fingerprints + stage-1 predictions.

    For each target, the feature matrix is augmented with stage-1 probability
    predictions of every OTHER task (multitask signal) and trained/tuned
    exactly like stage 1. Models are saved under ``checkpoint_dir/stage2``;
    per-task metrics are written to ``metrics_stage2.json`` and returned as
    ``{"metrics": ...}``.
    """
    training_cfg = config.get("training", {})
    base_params = training_cfg.get("lightgbm_params", {})
    n_trials = training_cfg.get("optuna_trials", 40)
    boosting_rounds = training_cfg.get("boosting_rounds", 1500)
    early_stopping = training_cfg.get("early_stopping_rounds", 100)
    seed = config.get("seed", 42)

    stage_dir = checkpoint_dir / "stage2"
    stage_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): n_train and n_val are computed but never used below.
    n_train = len(train_df)
    n_val = len(val_df) if val_df is not None else 0

    metrics: Dict[str, Dict] = {}

    for task_idx, task_name in enumerate(target_names):
        # Only rows labeled for this task contribute to training.
        mask = train_df[task_name].notna().values
        if mask.sum() == 0:
            metrics[task_name] = {"status": "skipped", "reason": "no labels"}
            continue

        augmented_train_matrix = _build_augmented_matrix(
            train_features[mask],
            stage1_train_preds[mask],
            task_idx,
        )
        y_train = train_df.loc[mask, task_name].astype(float).values

        # Validation data (and stage-1 validation predictions) are required
        # for early stopping and AUC reporting.
        if (
            val_features is None
            or val_df is None
            or stage1_val_preds is None
            or val_df[task_name].notna().sum() < 2
        ):
            metrics[task_name] = {"status": "skipped", "reason": "missing validation data"}
            continue

        val_mask = val_df[task_name].notna().values
        augmented_val_matrix = _build_augmented_matrix(
            val_features[val_mask],
            stage1_val_preds[val_mask],
            task_idx,
        )
        y_val = val_df.loc[val_mask, task_name].astype(float).values

        # ROC-AUC needs both classes present on both splits.
        if len(np.unique(y_val)) < 2 or len(np.unique(y_train)) < 2:
            metrics[task_name] = {"status": "skipped", "reason": "single-class labels"}
            continue

        task_result = train_lightgbm_task(
            augmented_train_matrix,
            y_train,
            augmented_val_matrix,
            y_val,
            base_params=base_params,
            boosting_rounds=boosting_rounds,
            early_stopping_rounds=early_stopping,
            n_trials=n_trials,
            seed=seed,
        )

        if task_result is None:
            metrics[task_name] = {"status": "skipped", "reason": "training failed"}
            continue

        model_path = stage_dir / f"{task_name}.pkl"
        joblib.dump(task_result.model, model_path)

        metrics[task_name] = {
            "val_auc": task_result.val_auc,
            "best_iteration": int(task_result.best_iteration),
        }

    save_stage_metrics(metrics, checkpoint_dir / "metrics_stage2.json")
    return {"metrics": metrics}
|
src/train_evaluate.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.metrics import roc_auc_score
|
| 5 |
+
|
| 6 |
+
def masked_bce_loss(logits, labels, mask):
    """
    logits: [batch_size, num_classes] (raw outputs)
    labels: [batch_size, num_classes] (0/1 with filler)
    mask: [batch_size, num_classes] (True if label is valid)
    """
    per_element = nn.BCEWithLogitsLoss(reduction="none")(logits, labels)
    weights = mask.float()
    # Average only over positions that carry a valid label.
    return (per_element * weights).sum() / weights.sum()
|
| 16 |
+
|
| 17 |
+
def train_model(model, loader, optimizer, device):
    """Run one training epoch; return the dataset-averaged masked BCE loss."""
    model.train()
    running = 0.0
    for graphs in loader:
        graphs = graphs.to(device)

        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.batch)  # [num_graphs, num_classes]
        batch_loss = masked_bce_loss(logits, graphs.y, graphs.mask)
        batch_loss.backward()
        optimizer.step()

        # Weight by graphs per batch so the average is per-graph, not per-batch.
        running += batch_loss.item() * graphs.num_graphs
    return running / len(loader.dataset)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the dataset-averaged masked BCE loss without gradient tracking."""
    model.eval()
    accumulated = 0.0
    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)
        accumulated += masked_bce_loss(logits, graphs.y, graphs.mask).item() * graphs.num_graphs
    return accumulated / len(loader.dataset)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Mean ROC-AUC across labels, skipping labels where AUC is undefined.

    Returns NaN when no label has a computable AUC.
    """
    model.eval()
    trues, probs, valids = [], [], []

    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)

        # Store predictions (sigmoid → probabilities)
        probs.append(torch.sigmoid(logits).cpu())
        trues.append(graphs.y.cpu())
        valids.append(graphs.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(trues, dim=0).numpy()
    y_pred = torch.cat(probs, dim=0).numpy()
    y_mask = torch.cat(valids, dim=0).numpy()

    per_label = []
    for col in range(y_true.shape[1]):  # per label
        valid = y_mask[:, col].astype(bool)
        if not valid.any():
            continue
        try:
            per_label.append(roc_auc_score(y_true[valid, col], y_pred[valid, col]))
        except ValueError:
            # happens if only one class present (all 0 or all 1)
            pass

    return np.mean(per_label) if per_label else float("nan")
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Compute per-class ROC-AUC and the nan-mean across classes.

    Returns:
        (auc_array, mean_auc): a float32 array with one AUC per label
        (NaN where the label had no valid data or only one class), and
        the mean over the non-NaN entries.
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []

    # The @torch.no_grad() decorator already disables gradients; the
    # redundant inner `with torch.no_grad():` block was removed for
    # consistency with compute_roc_auc.
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)

        # Store predictions (sigmoid → probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    # Compute AUC per class
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # in case only one class present
        else:
            auc = np.nan
        auc_list.append(auc)

    # Convert to numpy array for easier manipulation
    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs

    # Return both per-class and mean
    return auc_array, mean_auc
|
train.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
from src.constants import TARGET_NAMES
|
| 12 |
+
from src.features import FingerprintFeaturizer
|
| 13 |
+
from src.lightgbm_trainer import train_stage_one_models
|
| 14 |
+
from src.preprocess import load_tox21_dataset
|
| 15 |
+
from src.seed import set_seed
|
| 16 |
+
from src.stage_two import train_stage_two_models
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _default_checkpoint_dir(config: Dict) -> Path:
|
| 20 |
+
checkpoint_cfg = config.get("output", {})
|
| 21 |
+
checkpoint_dir = checkpoint_cfg.get("checkpoint_dir", "./checkpoints")
|
| 22 |
+
path = Path(checkpoint_dir)
|
| 23 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
return path
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def train(config: Dict):
    """Run the full two-stage training pipeline from a config dict.

    Steps: load env/seed → download dataset splits → featurize SMILES →
    train stage-1 LightGBM models → cache their predictions → optionally
    train stage-2 multitask-augmented models → write a training manifest.

    Raises ValueError when the dataset lacks 'train'/'validation' splits.
    """
    load_dotenv()
    set_seed(config.get("seed", 42))
    # Hugging Face token for private dataset access (may be None).
    token = os.getenv("TOKEN")

    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("name", "ml-jku/tox21")
    splits = load_tox21_dataset(token, dataset_name)

    if "train" not in splits or "validation" not in splits:
        raise ValueError("Dataset must provide 'train' and 'validation' splits.")

    # Standardize molecules and compute fingerprint feature matrices.
    featurizer = FingerprintFeaturizer(config.get("features", {}))
    train_df, train_features = featurizer.featurize_dataframe(splits["train"], "train")
    val_df, val_features = featurizer.featurize_dataframe(splits["validation"], "validation")

    checkpoint_dir = _default_checkpoint_dir(config)
    cache_dir = checkpoint_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("==== Stage 1: Training baseline LightGBM models ====")
    stage1_artifacts = train_stage_one_models(
        train_features,
        val_features,
        train_df,
        val_df,
        config,
        checkpoint_dir,
        target_names=TARGET_NAMES,
    )

    stage1_train_full = stage1_artifacts["train_full"]
    stage1_val_full = stage1_artifacts["val_full"]

    # Cache stage-1 prediction matrices so stage 2 / inference can reuse them.
    np.savez(
        cache_dir / "stage1_train_predictions.npz",
        full=stage1_train_full,
        target_names=np.array(TARGET_NAMES, dtype=object),
    )
    if stage1_val_full is not None:
        np.savez(
            cache_dir / "stage1_validation_predictions.npz",
            full=stage1_val_full,
            target_names=np.array(TARGET_NAMES, dtype=object),
        )

    # Stage 2 is opt-in via config["multitask"]["enabled"].
    stage2_metrics = None
    multitask_cfg = config.get("multitask", {"enabled": False})
    if multitask_cfg.get("enabled", False):
        print("==== Stage 2: Training multitask-augmented LightGBM models ====")
        stage2_artifacts = train_stage_two_models(
            train_features,
            val_features,
            train_df,
            val_df,
            config,
            checkpoint_dir,
            stage1_train_full,
            stage1_val_full,
            target_names=TARGET_NAMES,
        )
        stage2_metrics = stage2_artifacts["metrics"]

    stage2_entry = {
        "enabled": bool(multitask_cfg.get("enabled", False)),
        "model_dir": str(checkpoint_dir / "stage2") if stage2_metrics is not None else None,
        "metrics": str(checkpoint_dir / "metrics_stage2.json") if stage2_metrics is not None else None,
    }

    # Manifest captures everything needed to reproduce / load this run.
    manifest = {
        "feature_config": config.get("features", {}),
        "target_names": TARGET_NAMES,
        "dataset": dataset_cfg,
        "stage1": {
            "model_dir": str(checkpoint_dir / "stage1"),
            "metrics": str((checkpoint_dir / "metrics_stage1.json")),
        },
        "stage2": stage2_entry,
        "multitask": multitask_cfg,
        "seed": config.get("seed", 42),
    }

    manifest_path = checkpoint_dir / "training_manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    print("Training complete.")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
    # Entry point: load the JSON config and run the full training pipeline.
    with open("./config/config.json", "r", encoding="utf-8") as config_file:
        loaded_config = json.load(config_file)
    train(loaded_config)
|