initial commit
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- LICENSE +191 -0
- LICENSES/MIT-OpenAI-CLIP.txt +21 -0
- LICENSES/MIT-OpenCLIP.txt +25 -0
- NOTICE +60 -0
- README.md +83 -0
- assets/Raon-VisionEncoder-Gradient-Black.png +0 -0
- assets/Raon-VisionEncoder-Gradient-White.png +0 -0
- assets/photo.jpg +3 -0
- config.json +40 -0
- configuration_raonve.py +96 -0
- modeling_raonve.py +235 -0
- raon_vision_encoder/__init__.py +0 -0
- raon_vision_encoder/clip.py +287 -0
- raon_vision_encoder/constants.py +5 -0
- raon_vision_encoder/timm_model.py +397 -0
- raon_vision_encoder/tokenizer.py +193 -0
- raon_vision_encoder/transform.py +44 -0
- raon_vision_encoder/transformer.py +627 -0
- raon_vision_encoder/utils.py +16 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/photo.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
LICENSE
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to the Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by the Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding any notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
Copyright 2024-2026 Raon Vision Team
|
| 180 |
+
|
| 181 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 182 |
+
you may not use this file except in compliance with the License.
|
| 183 |
+
You may obtain a copy of the License at
|
| 184 |
+
|
| 185 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 186 |
+
|
| 187 |
+
Unless required by applicable law or agreed to in writing, software
|
| 188 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 189 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 190 |
+
See the License for the specific language governing permissions and
|
| 191 |
+
limitations under the License.
|
LICENSES/MIT-OpenAI-CLIP.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 OpenAI
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
LICENSES/MIT-OpenCLIP.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
|
| 4 |
+
Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
|
| 5 |
+
John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
|
| 6 |
+
Ludwig Schmidt
|
| 7 |
+
|
| 8 |
+
Permission is hereby granted, free of charge, to any person obtaining
|
| 9 |
+
a copy of this software and associated documentation files (the
|
| 10 |
+
"Software"), to deal in the Software without restriction, including
|
| 11 |
+
without limitation the rights to use, copy, modify, merge, publish,
|
| 12 |
+
distribute, sublicense, and/or sell copies of the Software, and to
|
| 13 |
+
permit persons to whom the Software is furnished to do so, subject to
|
| 14 |
+
the following conditions:
|
| 15 |
+
|
| 16 |
+
The above copyright notice and this permission notice shall be
|
| 17 |
+
included in all copies or substantial portions of the Software.
|
| 18 |
+
|
| 19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 20 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 21 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 22 |
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
| 23 |
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
| 24 |
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 25 |
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
NOTICE
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
raon-vision-encoder
|
| 2 |
+
Copyright 2024-2026 Raon Vision Team
|
| 3 |
+
|
| 4 |
+
This product includes software derived from the following projects:
|
| 5 |
+
|
| 6 |
+
===============================================================================
|
| 7 |
+
OpenCLIP
|
| 8 |
+
https://github.com/mlfoundations/open_clip
|
| 9 |
+
Licensed under the MIT License (see LICENSES/MIT-OpenCLIP.txt)
|
| 10 |
+
|
| 11 |
+
Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
|
| 12 |
+
Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
|
| 13 |
+
John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
|
| 14 |
+
Ludwig Schmidt
|
| 15 |
+
|
| 16 |
+
Used in: model/ and train/ packages (LocCa, CLIP, loss, factory,
|
| 17 |
+
transformer, data pipeline, training loop, etc.)
|
| 18 |
+
|
| 19 |
+
===============================================================================
|
| 20 |
+
OpenAI CLIP
|
| 21 |
+
https://github.com/openai/CLIP
|
| 22 |
+
Licensed under the MIT License (see LICENSES/MIT-OpenAI-CLIP.txt)
|
| 23 |
+
|
| 24 |
+
Copyright (c) 2021 OpenAI
|
| 25 |
+
|
| 26 |
+
Used in: model/tokenizer.py, model/bpe_simple_vocab_16e6.txt.gz
|
| 27 |
+
|
| 28 |
+
===============================================================================
|
| 29 |
+
Meta Platforms, Inc. (MAE / MoCo v3)
|
| 30 |
+
Licensed under the MIT License via OpenCLIP
|
| 31 |
+
|
| 32 |
+
Copyright (c) Meta Platforms, Inc. and affiliates
|
| 33 |
+
|
| 34 |
+
Used in: model/pos_embed.py (sincos position embedding utilities)
|
| 35 |
+
|
| 36 |
+
===============================================================================
|
| 37 |
+
timm (pytorch-image-models)
|
| 38 |
+
https://github.com/huggingface/pytorch-image-models
|
| 39 |
+
Licensed under the Apache License 2.0
|
| 40 |
+
|
| 41 |
+
Copyright (c) Ross Wightman
|
| 42 |
+
|
| 43 |
+
Used in: model/transform.py (ResizeKeepRatio)
|
| 44 |
+
|
| 45 |
+
===============================================================================
|
| 46 |
+
References
|
| 47 |
+
|
| 48 |
+
The following papers informed the design and implementation of features
|
| 49 |
+
in this software. Code was independently implemented unless noted above.
|
| 50 |
+
|
| 51 |
+
- CoCa: Yu et al., "CoCa: Contrastive Captioners are Image-Text Foundation Models", 2022
|
| 52 |
+
- SigLIP: Zhai et al., "Sigmoid Loss for Language Image Pre-Training", 2023
|
| 53 |
+
- SigLIP2: Tschannen et al., "SigLIP 2: Multilingual Vision-Language Encoders", 2025
|
| 54 |
+
- DINO: Caron et al., "Emerging Properties in Self-Supervised Vision Transformers", 2021
|
| 55 |
+
- DINOv2: Oquab et al., "DINOv2: Learning Robust Visual Features without Supervision", 2024
|
| 56 |
+
- SILC: Naeem et al., "SILC: Improving Vision Language Pretraining with Self-Distillation", 2023
|
| 57 |
+
- TIPS: Huang et al., "TIPS: Text-Image Pretraining with Spatial Awareness", 2024
|
| 58 |
+
- Koleo: Sablayrolles et al., "Spreading vectors for similarity search", ICLR 2019
|
| 59 |
+
- Gram Anchoring: Simeoni et al., "DINOv3", 2025 (independently implemented)
|
| 60 |
+
- NaFlex: from SigLIP2 / PaLI (independently implemented in PyTorch)
|
README.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- vision
|
| 5 |
+
- image-text
|
| 6 |
+
- clip
|
| 7 |
+
- zero-shot
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
<div align="center">
|
| 11 |
+
<img class="block dark:hidden" src="assets/Raon-VisionEncoder-Gradient-Black.png" alt="Raon VisionEncoder" width="600">
|
| 12 |
+
<img class="hidden dark:block" src="assets/Raon-VisionEncoder-Gradient-White.png" alt="Raon VisionEncoder" width="600">
|
| 13 |
+
</div>
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<a href="https://www.krafton.ai/ko/"><img src="https://img.shields.io/badge/Homepage-KRAFTON%20AI-blue?style=flat&logo=google-chrome&logoColor=white" alt="Homepage"></a>
|
| 17 |
+
<br>
|
| 18 |
+
<a href="https://huggingface.co/KRAFTON"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KRAFTON-yellow?style=flat" alt="Hugging Face"></a>
|
| 19 |
+
<a href="https://x.com/Krafton_AI"><img src="https://img.shields.io/badge/X-KRAFTON%20AI-white?style=flat&logo=x&logoColor=black" alt="X"></a>
|
| 20 |
+
<br>
|
| 21 |
+
<a href="https://www.apache.org/licenses/LICENSE-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-lightgrey?style=flat" alt="License"></a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
**Raon-VisionEncoder** is a 1.14B-parameter vision-language foundation model by [KRAFTON](https://www.krafton.com) for image and text feature extraction.
|
| 25 |
+
It supports zero-shot image classification, image-text retrieval, and native aspect ratio inference via NaFlex.
|
| 26 |
+
Built on [OpenCLIP](https://github.com/mlfoundations/open_clip) with a LocCa (Localized CoCa) architecture and ViT-SO400M vision encoder.
|
| 27 |
+
|
| 28 |
+
## Pretrained Models
|
| 29 |
+
|
| 30 |
+
| Model | Params (Inference) | Vision | Text | Patch Size | NaFlex Default Patches |
|
| 31 |
+
|-------|--------------------|--------|------|------------|------------------------|
|
| 32 |
+
| LocCa ViT-SO400M-16-SigLIP2 | 1.14B | 0.43B | 0.71B | 16x16 | 256 |
|
| 33 |
+
|
| 34 |
+
## Requirements
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
pip install torch torchvision timm transformers huggingface-hub safetensors ftfy
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Quick Start
|
| 41 |
+
|
| 42 |
+
```python
|
| 43 |
+
import torch
|
| 44 |
+
from transformers import AutoModel
|
| 45 |
+
from PIL import Image
|
| 46 |
+
|
| 47 |
+
# Load model + processor
|
| 48 |
+
model = AutoModel.from_pretrained("KRAFTON/Raon-VisionEncoder", trust_remote_code=True)
|
| 49 |
+
model = model.to(dtype=torch.bfloat16).eval()
|
| 50 |
+
processor = model.get_processor("KRAFTON/Raon-VisionEncoder")
|
| 51 |
+
|
| 52 |
+
# Encode image and text
|
| 53 |
+
img_inputs = processor(images=Image.open("assets/photo.jpg"))
|
| 54 |
+
txt_inputs = processor(text=["a cat", "a dog"])
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
img_feat = model.encode_image(**img_inputs)
|
| 58 |
+
txt_feat = model.encode_text(**txt_inputs)
|
| 59 |
+
|
| 60 |
+
# Compute similarity with learned scale and bias
|
| 61 |
+
logits = model.logit_scale.exp() * (img_feat @ txt_feat.T) + model.logit_bias
|
| 62 |
+
probs = logits.softmax(dim=-1)
|
| 63 |
+
print(probs)
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## API Reference
|
| 67 |
+
|
| 68 |
+
| Method | Input | Output |
|
| 69 |
+
|--------|-------|--------|
|
| 70 |
+
| `model.encode_image(**inputs)` | Processor output (image) | `[B, 1152]` normalized image features |
|
| 71 |
+
| `model.encode_text(**inputs)` | Processor output (text) | `[B, 1152]` normalized text features |
|
| 72 |
+
| `model.logit_scale` | - | Learned temperature parameter |
|
| 73 |
+
| `model.logit_bias` | - | Learned bias parameter |
|
| 74 |
+
| `model.get_processor(repo_id)` | HuggingFace repo ID | Processor instance |
|
| 75 |
+
| `processor(images=img)` | PIL Image | Preprocessed image dict |
|
| 76 |
+
| `processor(text=["a cat"])` | list of strings | Tokenized text dict |
|
| 77 |
+
|
| 78 |
+
## License
|
| 79 |
+
|
| 80 |
+
This repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
| 81 |
+
Third-party notices in [NOTICE](NOTICE).
|
| 82 |
+
|
| 83 |
+
© 2026 KRAFTON
|
assets/Raon-VisionEncoder-Gradient-Black.png
ADDED
|
assets/Raon-VisionEncoder-Gradient-White.png
ADDED
|
assets/photo.jpg
ADDED
|
Git LFS Details
|
config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"RaonVEModel"
|
| 4 |
+
],
|
| 5 |
+
"model_type": "raon_ve",
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_raonve.RaonVEConfig",
|
| 8 |
+
"AutoModel": "modeling_raonve.RaonVEModel"
|
| 9 |
+
},
|
| 10 |
+
"embed_dim": 1152,
|
| 11 |
+
"init_logit_bias": -10,
|
| 12 |
+
"vision_config": {
|
| 13 |
+
"image_size": 256,
|
| 14 |
+
"timm_model_name": "vit_so400m_patch16_siglip_256",
|
| 15 |
+
"timm_model_pretrained": false,
|
| 16 |
+
"timm_pool": "map",
|
| 17 |
+
"timm_proj": "none"
|
| 18 |
+
},
|
| 19 |
+
"text_config": {
|
| 20 |
+
"context_length": 64,
|
| 21 |
+
"vocab_size": 256000,
|
| 22 |
+
"hf_tokenizer_name": "timm/ViT-SO400M-16-SigLIP2-256",
|
| 23 |
+
"tokenizer_kwargs": {
|
| 24 |
+
"clean": "canonicalize"
|
| 25 |
+
},
|
| 26 |
+
"width": 1152,
|
| 27 |
+
"heads": 16,
|
| 28 |
+
"layers": 27,
|
| 29 |
+
"mlp_ratio": 3.7362,
|
| 30 |
+
"no_causal_mask": true,
|
| 31 |
+
"proj_bias": true,
|
| 32 |
+
"pool_type": "last",
|
| 33 |
+
"norm_kwargs": {
|
| 34 |
+
"eps": 1e-06
|
| 35 |
+
},
|
| 36 |
+
"act_kwargs": {
|
| 37 |
+
"approximate": "tanh"
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
}
|
configuration_raonve.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Raon-VisionEncoder configuration."""
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RaonVEVisionConfig(PretrainedConfig):
|
| 7 |
+
model_type = "raon_ve_vision"
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
image_size=256,
|
| 12 |
+
timm_model_name="vit_so400m_patch16_siglip_256",
|
| 13 |
+
timm_model_pretrained=False,
|
| 14 |
+
timm_pool="map",
|
| 15 |
+
timm_proj="none",
|
| 16 |
+
**kwargs,
|
| 17 |
+
):
|
| 18 |
+
super().__init__(**kwargs)
|
| 19 |
+
self.image_size = image_size
|
| 20 |
+
self.timm_model_name = timm_model_name
|
| 21 |
+
self.timm_model_pretrained = timm_model_pretrained
|
| 22 |
+
self.timm_pool = timm_pool
|
| 23 |
+
self.timm_proj = timm_proj
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RaonVETextConfig(PretrainedConfig):
|
| 27 |
+
model_type = "raon_ve_text"
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
context_length=64,
|
| 32 |
+
vocab_size=256000,
|
| 33 |
+
width=1152,
|
| 34 |
+
heads=16,
|
| 35 |
+
layers=27,
|
| 36 |
+
mlp_ratio=3.7362,
|
| 37 |
+
no_causal_mask=True,
|
| 38 |
+
proj_bias=True,
|
| 39 |
+
pool_type="last",
|
| 40 |
+
hf_tokenizer_name="timm/ViT-SO400M-16-SigLIP2-256",
|
| 41 |
+
tokenizer_kwargs=None,
|
| 42 |
+
norm_kwargs=None,
|
| 43 |
+
act_kwargs=None,
|
| 44 |
+
**kwargs,
|
| 45 |
+
):
|
| 46 |
+
super().__init__(**kwargs)
|
| 47 |
+
self.context_length = context_length
|
| 48 |
+
self.vocab_size = vocab_size
|
| 49 |
+
self.width = width
|
| 50 |
+
self.heads = heads
|
| 51 |
+
self.layers = layers
|
| 52 |
+
self.mlp_ratio = mlp_ratio
|
| 53 |
+
self.no_causal_mask = no_causal_mask
|
| 54 |
+
self.proj_bias = proj_bias
|
| 55 |
+
self.pool_type = pool_type
|
| 56 |
+
self.hf_tokenizer_name = hf_tokenizer_name
|
| 57 |
+
self.tokenizer_kwargs = tokenizer_kwargs or {"clean": "canonicalize"}
|
| 58 |
+
self.norm_kwargs = norm_kwargs or {"eps": 1e-6}
|
| 59 |
+
self.act_kwargs = act_kwargs or {"approximate": "tanh"}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RaonVEConfig(PretrainedConfig):
|
| 63 |
+
model_type = "raon_ve"
|
| 64 |
+
is_composition = True
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
embed_dim=1152,
|
| 69 |
+
init_logit_bias=-10,
|
| 70 |
+
vision_config=None,
|
| 71 |
+
text_config=None,
|
| 72 |
+
**kwargs,
|
| 73 |
+
):
|
| 74 |
+
super().__init__(**kwargs)
|
| 75 |
+
self.embed_dim = embed_dim
|
| 76 |
+
self.init_logit_bias = init_logit_bias
|
| 77 |
+
|
| 78 |
+
if isinstance(vision_config, dict):
|
| 79 |
+
self.vision_config = RaonVEVisionConfig(**vision_config)
|
| 80 |
+
elif vision_config is None:
|
| 81 |
+
self.vision_config = RaonVEVisionConfig()
|
| 82 |
+
else:
|
| 83 |
+
self.vision_config = vision_config
|
| 84 |
+
|
| 85 |
+
if isinstance(text_config, dict):
|
| 86 |
+
self.text_config = RaonVETextConfig(**text_config)
|
| 87 |
+
elif text_config is None:
|
| 88 |
+
self.text_config = RaonVETextConfig()
|
| 89 |
+
else:
|
| 90 |
+
self.text_config = text_config
|
| 91 |
+
|
| 92 |
+
def to_dict(self):
|
| 93 |
+
output = super().to_dict()
|
| 94 |
+
output["vision_config"] = self.vision_config.to_dict()
|
| 95 |
+
output["text_config"] = self.text_config.to_dict()
|
| 96 |
+
return output
|
modeling_raonve.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Raon-VisionEncoder model."""
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
from transformers import PreTrainedModel
|
| 10 |
+
|
| 11 |
+
from .configuration_raonve import RaonVEConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Hub repo id consulted by _ensure_raon_package when the companion package
# must be downloaded; set as a side effect of RaonVEModel.from_pretrained.
_raon_repo_id = None


def set_repo_id(repo_id):
    """Record the HF Hub repo id to fetch raon_vision_encoder from."""
    global _raon_repo_id
    _raon_repo_id = repo_id
|
| 19 |
+
|
| 20 |
+
def _ensure_raon_package():
    """Import raon_vision_encoder, downloading from HF Hub if needed."""
    # Fast path: the package is already importable from the environment.
    try:
        return importlib.import_module("raon_vision_encoder.clip").CustomTextCLIP
    except (ImportError, ModuleNotFoundError):
        pass

    # Slow path: pull just the package files from the Hub onto sys.path.
    from huggingface_hub import snapshot_download

    repo_id = _raon_repo_id or "KRAFTON/Raon-VisionEncoder"
    repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
    sys.path.insert(0, repo_dir)

    # Evict any partially-imported submodules so the fresh import starts clean.
    for mod_name in list(sys.modules):
        if mod_name.startswith("raon_vision_encoder"):
            del sys.modules[mod_name]

    return importlib.import_module("raon_vision_encoder.clip").CustomTextCLIP
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class RaonVEPreTrainedModel(PreTrainedModel):
    """Base class hooking RaonVE models into the HF PreTrainedModel machinery."""

    config_class = RaonVEConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Weights come fully initialized from the inner CustomTextCLIP build
        # (or from the checkpoint), so no extra initialization is done here.
        pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RaonVEModel(RaonVEPreTrainedModel):
    """HF wrapper around the OpenCLIP-style `CustomTextCLIP` dual encoder.

    The concrete towers come from the `raon_vision_encoder` package, which is
    imported lazily (and downloaded from the Hub when absent) by
    `_ensure_raon_package`.
    """

    config_class = RaonVEConfig

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Remember where we were loaded from so _ensure_raon_package can fetch
        # the companion Python package from the same repository if needed.
        set_repo_id(str(pretrained_model_name_or_path))
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    def __init__(self, config: RaonVEConfig):
        super().__init__(config)

        # Re-pack the HF sub-configs into the plain dicts expected by the
        # OpenCLIP-style CustomTextCLIP constructor.
        vision_cfg = {
            "image_size": config.vision_config.image_size,
            "timm_model_name": config.vision_config.timm_model_name,
            "timm_model_pretrained": config.vision_config.timm_model_pretrained,
            "timm_pool": config.vision_config.timm_pool,
            "timm_proj": config.vision_config.timm_proj,
        }
        text_cfg = {
            "context_length": config.text_config.context_length,
            "vocab_size": config.text_config.vocab_size,
            "width": config.text_config.width,
            "heads": config.text_config.heads,
            "layers": config.text_config.layers,
            "mlp_ratio": config.text_config.mlp_ratio,
            "no_causal_mask": config.text_config.no_causal_mask,
            "proj_bias": config.text_config.proj_bias,
            "pool_type": config.text_config.pool_type,
            "hf_tokenizer_name": config.text_config.hf_tokenizer_name,
            "tokenizer_kwargs": config.text_config.tokenizer_kwargs,
            "norm_kwargs": config.text_config.norm_kwargs,
            "act_kwargs": config.text_config.act_kwargs,
        }

        CustomTextCLIP = _ensure_raon_package()
        inner = CustomTextCLIP(
            embed_dim=config.embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            init_logit_bias=config.init_logit_bias,
        )

        # Re-register the inner modules/parameters directly on this model so
        # state_dict keys follow the flat layout (visual.*, text.*, ...).
        self.visual = inner.visual
        self.text = inner.text
        self.logit_scale = inner.logit_scale
        self.logit_bias = inner.logit_bias

        # Enable NaFlex by default
        self.visual._setup_1d_forward()

        self.post_init()

    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
        """Encode images to normalized feature vectors [B, 1152].
        Pass the output of processor(images=...) directly via **inputs.
        """
        # Only forward the optional NaFlex arguments when given, so a plain
        # fixed-size `visual(pixel_values)` call still works.
        kwargs = {}
        if pixel_attention_mask is not None:
            kwargs["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            kwargs["spatial_shapes"] = spatial_shapes
        features = self.visual(pixel_values, **kwargs) if kwargs else self.visual(pixel_values)
        return F.normalize(features, dim=-1)

    def encode_text(self, input_ids):
        """Encode text to normalized feature vectors [B, 1152].
        Pass the output of processor(text=...) directly via **inputs.
        """
        features = self.text(input_ids)
        return F.normalize(features, dim=-1)

    def forward(self, pixel_values=None, input_ids=None, pixel_attention_mask=None, spatial_shapes=None):
        """Run either or both towers; absent modalities yield None features.

        NOTE(review): this returns the raw `logit_scale` parameter, whereas
        `CustomTextCLIP.forward` returns `logit_scale.exp()` — confirm which
        convention downstream callers expect.
        """
        image_features = None
        text_features = None

        if pixel_values is not None:
            image_features = self.encode_image(
                pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                spatial_shapes=spatial_shapes,
            )
        if input_ids is not None:
            text_features = self.encode_text(input_ids)

        output = {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }
        return output

    @staticmethod
    def get_processor(pretrained_model_name_or_path, **kwargs):
        """Get the processor for this model."""
        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class RaonVEProcessor:
    """Image and text processor for Raon-VisionEncoder.

    Preprocesses images into NaFlex patch sequences and tokenizes text.

    Args:
        max_num_patches: Maximum number of patches per image (controls resolution).
            Higher values preserve more detail. Default: 256.
    """

    # Resolution budget used by __call__ when max_num_patches is not given.
    DEFAULT_MAX_PATCHES = 256

    def __init__(self, patch_size=16, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), tokenizer=None):
        # Deferred import: torchvision is only required once a processor is built.
        from torchvision import transforms as T
        self.patch_size = patch_size
        self.mean, self.std = mean, std
        self.tokenizer = tokenizer
        # PIL image -> [0, 1] CHW float tensor, then per-channel normalization
        # with OpenAI-CLIP statistics by default.
        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local directory's or Hub repo's config.json."""
        import json
        from pathlib import Path as _Path
        if _Path(pretrained_model_name_or_path).is_dir():
            cfg_path = _Path(pretrained_model_name_or_path) / "config.json"
        else:
            from huggingface_hub import hf_hub_download
            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
        with open(cfg_path) as f:
            cfg = json.load(f)
        v = cfg.get("vision_config", {}); t = cfg.get("text_config", {})
        # Infer the patch size from the timm model name (e.g. "..._patch16_...");
        # fall back to 16 when no "patchNN" token is present.
        ps = 16
        for part in v.get("timm_model_name", "").split("_"):
            if part.startswith("patch") and part[5:].isdigit():
                ps = int(part[5:]); break
        tokenizer = None
        if t.get("hf_tokenizer_name"):
            # The HFTokenizer wrapper lives in the downloadable companion package.
            _ensure_raon_package()
            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
            tokenizer = tok_mod.HFTokenizer(
                t["hf_tokenizer_name"], context_length=t.get("context_length", 64),
                tokenizer_mode=t.get("tokenizer_mode"), **t.get("tokenizer_kwargs", {}),
            )
        return cls(patch_size=ps, tokenizer=tokenizer)

    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
        """Process images and/or text.

        Args:
            images: PIL Image or list of PIL Images.
            text: String or list of strings.
            max_num_patches: Resolution budget (default: 256). Higher = more detail.
            return_tensors: Accepted for HF API compatibility; outputs are
                always PyTorch tensors regardless of this value.

        Returns:
            Dict with 'pixel_values', 'pixel_attention_mask', 'spatial_shapes' for images
            and/or 'input_ids' for text.
        """
        from PIL import Image
        result = {}
        if images is not None:
            mnp = max_num_patches or self.DEFAULT_MAX_PATCHES
            _ensure_raon_package()
            transform_mod = importlib.import_module("raon_vision_encoder.transform")
            get_size = transform_mod.get_image_size_for_max_num_patches
            imgs = [images] if isinstance(images, Image.Image) else images
            ps = self.patch_size
            all_p, all_m, all_s = [], [], []
            for img in imgs:
                img = img.convert("RGB")
                w, h = img.size
                # Largest patch-aligned (th, tw) whose patch count fits the budget.
                th, tw = get_size(h, w, ps, mnp)
                t = self._post(img.resize((tw, th), Image.BICUBIC))
                gh, gw = th // ps, tw // ps
                n = gh * gw
                # [C, gh, ps, gw, ps] -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
                patches = t.reshape(3, gh, ps, gw, ps).permute(1,3,0,2,4).reshape(n, 3*ps*ps)
                # Pad each image to exactly `mnp` patch slots; the boolean mask
                # marks which slots hold real (non-padding) patches.
                padded = torch.zeros(mnp, ps*ps*3); padded[:n] = patches
                mask = torch.zeros(mnp, dtype=torch.bool); mask[:n] = True
                all_p.append(padded); all_m.append(mask)
                all_s.append(torch.tensor([gh, gw]))
            result["pixel_values"] = torch.stack(all_p)
            result["pixel_attention_mask"] = torch.stack(all_m)
            result["spatial_shapes"] = torch.stack(all_s)
        if text is not None:
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized.")
            result["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
        return result
|
raon_vision_encoder/__init__.py
ADDED
|
File without changes
|
raon_vision_encoder/clip.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
from .timm_model import TimmModel
|
| 13 |
+
from .transformer import (
|
| 14 |
+
LayerNormFp32,
|
| 15 |
+
LayerNorm,
|
| 16 |
+
QuickGELU,
|
| 17 |
+
TextTransformer,
|
| 18 |
+
text_global_pool,
|
| 19 |
+
)
|
| 20 |
+
from .utils import to_2tuple
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class CLIPVisionCfg:
    """Vision tower configuration (OpenCLIP-style).

    Only the timm_* path is supported by `_build_vision_tower` in this
    package; the remaining fields mirror upstream OpenCLIP for checkpoint
    config compatibility.
    """

    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # LayerScale init value (None disables)
    patch_dropout: float = 0.0  # fraction of patches dropped; 0 disables
    attentional_pool: bool = False
    attn_pooler_queries: int = 256
    attn_pooler_heads: int = 8
    no_ln_pre: bool = False
    pos_embed_type: str = "learnable"
    final_ln_after_pool: bool = False
    pool_type: str = "tok"
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # extra kwargs for the activation layer
    norm_kwargs: Optional[dict] = None  # extra kwargs for the norm layer

    # Attention/block variants (mirror transformer.py options).
    block_type: Optional[str] = None
    qk_norm: bool = False
    scaled_cosine_attn: bool = False
    scale_heads: bool = False
    scale_attn_inner: bool = False
    scale_attn: bool = False
    scale_fc: bool = False

    # timm backbone options — the path actually used by _build_vision_tower.
    timm_model_name: Optional[str] = None  # a valid timm model name
    timm_model_pretrained: bool = False  # load timm pretrained weights
    timm_pool: str = "avg"  # '', 'avg', 'abs_attn', 'rot_attn'
    timm_proj: str = "linear"  # 'linear', 'mlp', 'none'
    timm_proj_bias: bool = False
    timm_drop: float = 0.0  # head dropout rate
    timm_drop_path: Optional[float] = None  # stochastic depth rate
    timm_use_rope: bool = False  # inject 2D rotary position embeddings
    timm_rope_keep_ape: bool = False  # keep absolute pos embed alongside RoPE
    timm_dynamic_img_size: bool = False
    timm_norm_pre: bool = False  # add a pre-blocks LayerNorm if trunk has none
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class CLIPTextCfg:
    """Text tower configuration (OpenCLIP-style)."""

    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None  # HF Hub tokenizer repo/name
    tokenizer_mode: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # LayerScale init value (None disables)
    embed_cls: bool = False  # prepend a learned CLS embedding
    pad_id: int = 0
    eos_id: int = 2
    no_causal_mask: bool = False  # disable the causal attention mask
    final_ln_after_pool: bool = False
    pool_type: str = "argmax"
    proj_bias: bool = False
    proj_type: str = "linear"
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # default is None, so Optional is the correct annotation
    norm_kwargs: Optional[dict] = None  # default is None, so Optional is the correct annotation

    # Attention/block variants (mirror transformer.py options).
    block_type: Optional[str] = None
    qk_norm: bool = False
    scaled_cosine_attn: bool = False
    scale_heads: bool = False
    scale_attn_inner: bool = False
    scale_attn: bool = False
    scale_fc: bool = False

    # HF pretrained text-model path (not used by _build_text_tower here).
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = "mlp"
    hf_pooler_type: str = "mean_pooler"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_cast_dtype(precision: str):
    """Map a precision string to its torch dtype; None for anything else."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.float16
    return None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the timm-backed vision tower described by `vision_cfg`.

    `quick_gelu` and `cast_dtype` are accepted for interface parity with the
    text tower builder; the timm path does not consume them.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if not vision_cfg.timm_model_name:
        raise ValueError(
            "Only TimmModel-based vision towers are supported in raon-vision-encoder. "
            "Please set timm_model_name in vision_cfg."
        )

    # A patch_dropout of 0 means "disabled" — timm expects None in that case.
    patch_drop = vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None

    return TimmModel(
        vision_cfg.timm_model_name,
        pretrained=vision_cfg.timm_model_pretrained,
        pool=vision_cfg.timm_pool,
        proj=vision_cfg.timm_proj,
        proj_bias=vision_cfg.timm_proj_bias,
        drop=vision_cfg.timm_drop,
        drop_path=vision_cfg.timm_drop_path,
        patch_drop=patch_drop,
        init_values=vision_cfg.ls_init_value,
        qk_norm=vision_cfg.qk_norm,
        use_rope=vision_cfg.timm_use_rope,
        rope_keep_ape=vision_cfg.timm_rope_keep_ape,
        dynamic_img_size=vision_cfg.timm_dynamic_img_size,
        norm_pre=vision_cfg.timm_norm_pre,
        embed_dim=embed_dim,
        image_size=vision_cfg.image_size,
        output_tokens=vision_cfg.output_tokens,
    )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _build_text_tower(
    embed_dim: int,
    text_cfg: CLIPTextCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the TextTransformer described by `text_cfg`.

    `quick_gelu` swaps nn.GELU for QuickGELU; fp16/bf16 `cast_dtype` selects
    the fp32-upcasting LayerNorm variant for numeric stability.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # LayerNormFp32 computes LN statistics in fp32 under half-precision casts.
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )
    # Bind any extra kwargs (e.g. eps, approximate) into the layer factories.
    if text_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
    if text_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **text_cfg.act_kwargs)

    text = TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        mlp_ratio=text_cfg.mlp_ratio,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        no_causal_mask=text_cfg.no_causal_mask,
        pad_id=text_cfg.pad_id,
        eos_id=text_cfg.eos_id,
        pool_type=text_cfg.pool_type,
        proj_type=text_cfg.proj_type,
        proj_bias=text_cfg.proj_bias,
        output_tokens=text_cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
        block_type=text_cfg.block_type,
        qk_norm=text_cfg.qk_norm,
        scaled_cosine_attn=text_cfg.scaled_cosine_attn,
        scale_heads=text_cfg.scale_heads,
        scale_attn_inner=text_cfg.scale_attn_inner,
        scale_attn=text_cfg.scale_attn,
        scale_fc=text_cfg.scale_fc,
    )
    return text
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class CustomTextCLIP(nn.Module):
    """CLIP with independently configured vision and text towers.

    Produces joint image/text embeddings plus a learned (log-space) logit
    scale and an optional SigLIP-style logit bias.
    """

    output_dict: torch.jit.Final[bool]

    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        # logit_scale lives in log space; shape [1] when a nonscalar is requested.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def encode_image(
        self, pixel_values, normalize: bool = False, pixel_attention_mask=None, spatial_shapes=None
    ):
        """Encode images; optionally L2-normalize the returned features."""
        # Forward the NaFlex extras only when given so a fixed-size call
        # still works with towers that don't accept these kwargs.
        kwargs = {}
        if pixel_attention_mask is not None:
            kwargs["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            kwargs["spatial_shapes"] = spatial_shapes
        features = self.visual(pixel_values, **kwargs) if kwargs else self.visual(pixel_values)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, input_ids, normalize: bool = False):
        """Encode token ids; optionally L2-normalize the returned features."""
        features = self.text(input_ids)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Return (image->text, text->image) similarity logit matrices."""
        image_features = self.encode_image(pixel_values=image, normalize=True)
        text_features = self.encode_text(input_ids=text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
        self, image=None, text=None, patch_valid_mask=None, spatial_shapes=None
    ):
        """Encode whichever modalities are given; features are L2-normalized."""
        image_features = (
            self.encode_image(
                pixel_values=image,
                normalize=True,
                pixel_attention_mask=patch_valid_mask,
                spatial_shapes=spatial_shapes,
            )
            if image is not None
            else None
        )
        text_features = (
            self.encode_text(input_ids=text, normalize=True) if text is not None else None
        )

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp(),
            }
            if self.logit_bias is not None:
                out_dict["logit_bias"] = self.logit_bias
            return out_dict

        # Tuple output for training loops that unpack positionally; the bias
        # is appended only when present.
        if self.logit_bias is not None:
            return (
                image_features,
                text_features,
                self.logit_scale.exp(),
                self.logit_bias,
            )
        return image_features, text_features, self.logit_scale.exp()
|
raon_vision_encoder/constants.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
# Per-channel RGB normalization statistics used by the OpenAI CLIP models.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
# Inception-style normalization: maps [0, 1] inputs to roughly [-1, 1].
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
|
raon_vision_encoder/timm_model.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import types
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from typing import Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import timm
|
| 13 |
+
from timm.layers import RotAttentionPool2d
|
| 14 |
+
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
| 15 |
+
from timm.layers import Mlp, to_2tuple
|
| 16 |
+
from timm.layers import AttentionRope, RotaryEmbeddingCat
|
| 17 |
+
except ImportError:
|
| 18 |
+
timm = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TimmModel(nn.Module):
|
| 22 |
+
"""timm model adapter"""
|
| 23 |
+
|
| 24 |
+
    def __init__(
        self,
        model_name: str,
        embed_dim: int,
        image_size: Union[int, Tuple[int, int]] = 224,
        pool: str = "avg",
        proj: str = "linear",
        proj_bias: bool = False,
        drop: float = 0.0,
        drop_path: Optional[float] = None,
        patch_drop: Optional[float] = None,
        init_values: Optional[float] = None,
        qk_norm: bool = False,
        use_rope: bool = False,
        rope_keep_ape: bool = False,
        dynamic_img_size: bool = False,
        norm_pre: bool = False,
        pretrained: bool = False,
        output_tokens: bool = False,
    ):
        """Create a timm backbone plus a CLIP projection head.

        Args:
            model_name: timm model identifier passed to `timm.create_model`.
            embed_dim: output embedding dimension of the head.
            pool: '' / 'avg' (timm global pool) or 'abs_attn' / 'rot_attn'
                (custom attention pooling added in the head).
            proj: projection head type — 'linear', 'mlp', or 'none'.
            use_rope: replace attention with a RoPE-aware variant and set up
                2D rotary embeddings via `_setup_rope`.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError(
                "Please install the latest timm (`pip install timm`) to use timm based models."
            )
        self.image_size = to_2tuple(image_size)
        self.output_tokens = output_tokens

        # Only pass kwargs that were explicitly requested; timm defaults apply otherwise.
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs["drop_path_rate"] = drop_path
        if patch_drop is not None:
            timm_kwargs["patch_drop_rate"] = patch_drop
        if init_values is not None:
            timm_kwargs["init_values"] = init_values
        if qk_norm:
            timm_kwargs["qk_norm"] = True
        if dynamic_img_size:
            timm_kwargs["dynamic_img_size"] = True
        if use_rope:

            class _AttentionRopeNoPrefix(AttentionRope):
                """AttentionRope with num_prefix_tokens=0 for models without cls token."""

                def __init__(self, *args, **kwargs):
                    kwargs["num_prefix_tokens"] = 0
                    super().__init__(*args, **kwargs)

            timm_kwargs["attn_layer"] = _AttentionRopeNoPrefix
            if not rope_keep_ape:
                # Drop the absolute position embedding unless explicitly kept.
                timm_kwargs["pos_embed"] = "none"

        custom_pool = pool in ("abs_attn", "rot_attn")
        if proj:
            assert proj in ("linear", "mlp", "none")
        extra_proj = proj in ("linear", "mlp")
        if not extra_proj and not custom_pool:
            # No separate head needed: let timm supply pooling and (optionally)
            # a classifier-as-projection of size embed_dim.
            proj_dim = 0 if proj == "none" else embed_dim
            self.trunk = timm.create_model(
                model_name,
                num_classes=proj_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            # Build the bare trunk and attach pooling/projection in self.head.
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get("pool_size", None)
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                # Custom attention pooling requires a 2D (spatial) feature map.
                assert feature_ndim == 2
                self.trunk.reset_classifier(0, global_pool="")
            else:
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # Optional attention pooling stage (only reachable via the else-branch
        # above, so feat_size is defined whenever it is used here).
        if pool == "abs_attn":
            head_layers["pool"] = AbsAttentionPool2d(
                prev_chs, feat_size=feat_size, out_features=embed_dim
            )
            prev_chs = embed_dim
        elif pool == "rot_attn":
            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE: these head projections are only built when custom_pool or an
        # earlier branch left features at trunk width.
        if proj == "linear":
            head_layers["drop"] = nn.Dropout(drop)
            head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == "mlp":
            head_layers["mlp"] = Mlp(
                prev_chs,
                2 * embed_dim,
                embed_dim,
                drop=(drop, 0),
                bias=(True, proj_bias),
            )

        self.head = nn.Sequential(head_layers)

        # Optionally insert a pre-blocks LayerNorm when the trunk has none.
        if (
            norm_pre
            and hasattr(self.trunk, "norm_pre")
            and isinstance(self.trunk.norm_pre, nn.Identity)
        ):
            self.trunk.norm_pre = nn.LayerNorm(self.trunk.embed_dim)
            logging.info(
                f"Replaced norm_pre Identity with LayerNorm({self.trunk.embed_dim})"
            )

        self._has_rope = use_rope
        if use_rope:
            self._setup_rope()
|
| 144 |
+
|
| 145 |
+
    def _setup_rope(self):
        """Inject 2D Rotary Position Embedding into the timm trunk.

        Replaces every trunk block's ``forward`` with a variant that threads a
        rope tensor and an attention mask into ``blk.attn``, and replaces the
        trunk's ``forward_features`` with a variable-resolution version that
        interpolates the absolute pos_embed and computes per-input ROPE
        frequencies. Assumes the trunk is a timm ViT with ``blocks``,
        ``patch_embed``, ``pos_embed``, ``norm_pre`` and ``norm``.
        """
        # All blocks share one attention config; read head layout from block 0.
        num_heads = self.trunk.blocks[0].attn.num_heads
        head_dim = self.trunk.embed_dim // num_heads

        # Allow inputs whose size differs from the pretraining resolution.
        self.trunk.patch_embed.strict_img_size = False

        self.rope = RotaryEmbeddingCat(
            dim=head_dim,
            max_res=max(self.image_size),
            in_pixels=True,
        )

        def _block_forward_rope(block_self, x, rope=None, attn_mask=None):
            # Standard pre-norm ViT block, except rope/attn_mask are passed
            # down to attention. NOTE(review): assumes block_self.attn accepts
            # a `rope` keyword (timm EVA-style attention) — confirm timm version.
            x = x + block_self.drop_path1(
                block_self.ls1(
                    block_self.attn(block_self.norm1(x), rope=rope, attn_mask=attn_mask)
                )
            )
            x = x + block_self.drop_path2(
                block_self.ls2(block_self.mlp(block_self.norm2(x)))
            )
            return x

        for blk in self.trunk.blocks:
            blk.forward = types.MethodType(_block_forward_rope, blk)

        # Captured by the closure below so the patched trunk method can reach
        # back to this wrapper's rope module.
        timm_model_ref = self
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_rope(trunk_self, x, attn_mask=None):
            # Lazy imports: keep optional deps off the module import path.
            from torch.utils.checkpoint import checkpoint
            from timm.layers import resample_abs_pos_embed

            # Patch-grid (rows, cols) of this specific input.
            ps = trunk_self.patch_embed.patch_size
            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]

            x = trunk_self.patch_embed(x)
            if x.ndim == 4:
                # (B, H', W', C) -> (B, H'*W', C) token sequence.
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                if x.shape[1] != trunk_self.pos_embed.shape[1]:
                    # Resolution mismatch: interpolate the absolute pos_embed
                    # to the current grid.
                    x = x + resample_abs_pos_embed(
                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
                    )
                else:
                    x = x + trunk_self.pos_embed
            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)

            # ROPE frequencies for this grid, shared by all blocks.
            rot_pos_embed = timm_model_ref.rope.get_embed(shape=grid_shape)

            # Convert boolean validity mask (True = keep) to an additive
            # 0 / -inf mask shaped (B, 1, 1, L) for SDPA broadcasting.
            _sdpa_mask = None
            if attn_mask is not None:
                _sdpa_mask = torch.zeros_like(attn_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~attn_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)

            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    x = checkpoint(
                        blk,
                        x,
                        rope=rot_pos_embed,
                        attn_mask=_sdpa_mask,
                        use_reentrant=False,
                    )
                else:
                    x = blk(x, rope=rot_pos_embed, attn_mask=_sdpa_mask)

            x = trunk_self.norm(x)
            return x

        self.trunk.forward_features = types.MethodType(
            _forward_features_rope, self.trunk
        )
| 221 |
+
|
| 222 |
+
    def _setup_dynamic_pos_embed(self):
        """Patch forward_features for variable-resolution pos_embed interpolation (non-RoPE).

        Replaces the trunk's ``forward_features`` with a version that resamples
        the absolute position embedding to the current input's patch grid and
        optionally applies a patch-validity attention mask.
        """
        # Allow inputs whose size differs from the pretraining resolution.
        self.trunk.patch_embed.strict_img_size = False
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_dynamic(trunk_self, x, patch_valid_mask=None):
            # Lazy imports: keep optional deps off the module import path.
            from torch.utils.checkpoint import checkpoint
            from timm.layers import resample_abs_pos_embed

            # Patch-grid (rows, cols) of this specific input.
            ps = trunk_self.patch_embed.patch_size
            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]

            x = trunk_self.patch_embed(x)
            if x.ndim == 4:
                # (B, H', W', C) -> (B, H'*W', C) token sequence.
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                if x.shape[1] != trunk_self.pos_embed.shape[1]:
                    # Resolution mismatch: interpolate pos_embed to this grid.
                    x = x + resample_abs_pos_embed(
                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
                    )
                else:
                    x = x + trunk_self.pos_embed
            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)

            # Convert boolean validity mask (True = keep) to an additive
            # 0 / -inf mask shaped (B, 1, 1, L) for SDPA broadcasting.
            _sdpa_mask = None
            if patch_valid_mask is not None:
                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)

            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    if _sdpa_mask is not None:
                        x = checkpoint(
                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
                        )
                    else:
                        x = checkpoint(blk, x, use_reentrant=False)
                else:
                    # NOTE(review): assumes timm blocks accept an `attn_mask`
                    # keyword — requires a recent timm; confirm pinned version.
                    x = blk(x, attn_mask=_sdpa_mask)

            x = trunk_self.norm(x)
            return x

        self.trunk.forward_features = types.MethodType(
            _forward_features_dynamic, self.trunk
        )
| 270 |
+
|
| 271 |
+
    def _setup_1d_forward(self):
        """Patch forward_features for NaFlex 1D mode (SigLIP2 style).

        Installs ``trunk._forward_features_1d`` which consumes pre-patchified
        input of shape (B, L, patch_dim) plus per-sample ``spatial_shapes``
        (gh, gw), applying the conv patch-embed as an equivalent linear layer
        and resizing the absolute pos_embed per sample.
        """
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_1d(
            trunk_self, x, patch_valid_mask=None, spatial_shapes=None
        ):
            from torch.utils.checkpoint import checkpoint

            # The conv patch embed applied to flattened patches is exactly a
            # linear layer with the conv kernel flattened to (D, patch_dim).
            conv = trunk_self.patch_embed.proj
            D = conv.weight.shape[0]
            x = torch.nn.functional.linear(
                x.to(conv.weight.dtype), conv.weight.reshape(D, -1), conv.bias
            )

            if (
                hasattr(trunk_self, "pos_embed")
                and trunk_self.pos_embed is not None
                and spatial_shapes is not None
            ):
                # Per-sample bilinear resize of the pretraining pos_embed grid
                # to each sample's own (gh, gw) grid.
                pos_embed = trunk_self.pos_embed
                base_n = pos_embed.shape[1]
                base_grid = int(base_n**0.5)  # assumes square grid with no prefix tokens in pos_embed — TODO confirm
                pos_2d = (
                    pos_embed.reshape(1, base_grid, base_grid, -1)
                    .permute(0, 3, 1, 2)
                    .float()
                )

                B, sl, D_emb = x.shape
                pos_resized = torch.zeros(B, sl, D_emb, device=x.device, dtype=x.dtype)

                for i in range(B):
                    gh, gw = spatial_shapes[i].tolist()
                    pe = torch.nn.functional.interpolate(
                        pos_2d, size=(gh, gw), mode="bilinear", align_corners=False
                    )
                    pe = pe.squeeze(0).permute(1, 2, 0).reshape(gh * gw, -1).to(x.dtype)
                    n_patches = gh * gw
                    pos_resized[i, :n_patches] = pe
                    if n_patches < sl:
                        # Padding slots get a constant (first) position vector;
                        # they are masked out of attention below anyway.
                        pos_resized[i, n_patches:] = pe[0]

                x = x + pos_resized
            elif hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                x = x + trunk_self.pos_embed

            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)

            # Convert boolean validity mask (True = keep) to an additive
            # 0 / -inf mask shaped (B, 1, 1, L) for SDPA broadcasting.
            _sdpa_mask = None
            if patch_valid_mask is not None:
                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)

            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    if _sdpa_mask is not None:
                        x = checkpoint(
                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
                        )
                    else:
                        x = checkpoint(blk, x, use_reentrant=False)
                else:
                    # NOTE(review): assumes timm blocks accept `attn_mask`.
                    x = blk(x, attn_mask=_sdpa_mask)

            x = trunk_self.norm(x)
            return x

        self.trunk._forward_features_1d = types.MethodType(
            _forward_features_1d, self.trunk
        )
        self._has_1d_forward = True
| 345 |
+
|
| 346 |
+
def forward_patch_features(self, x):
|
| 347 |
+
"""Forward pass returning per-patch features (before pooling/projection)."""
|
| 348 |
+
return self.trunk.forward_features(x)
|
| 349 |
+
|
| 350 |
+
    def forward(self, x, patch_valid_mask=None, spatial_shapes=None):
        """Forward pass: trunk features -> (masked) pooling -> projection head.

        Args:
            x: A (B, C, H, W) image batch, or — when ``spatial_shapes`` is
               given and 1D mode is installed — pre-patchified (B, L, P) tokens.
            patch_valid_mask: Optional (B, L) boolean mask, True for real
                (non-padding) patches — TODO confirm shape against callers.
            spatial_shapes: Optional per-sample (gh, gw) grids for NaFlex 1D mode.

        Returns:
            The pooled/projected embedding, or ``(pooled, patch_features)``
            when ``self.output_tokens`` is set.
        """
        # 1) Per-patch features via whichever trunk forward variant applies.
        if spatial_shapes is not None and getattr(self, "_has_1d_forward", False):
            patch_features = self.trunk._forward_features_1d(
                x, patch_valid_mask=patch_valid_mask, spatial_shapes=spatial_shapes
            )
        elif patch_valid_mask is not None and self._has_rope:
            # The RoPE-patched forward_features takes the mask as `attn_mask`.
            patch_features = self.trunk.forward_features(x, attn_mask=patch_valid_mask)
        elif patch_valid_mask is not None:
            patch_features = self.trunk.forward_features(
                x, patch_valid_mask=patch_valid_mask
            )
        else:
            patch_features = self.trunk.forward_features(x)
        # 2) Zero out padded patches so they cannot leak into pooling.
        if patch_valid_mask is not None:
            mask_f = patch_valid_mask.unsqueeze(-1).to(
                patch_features.dtype
            )
            patch_features = patch_features * mask_f
        # Cached so callers can retrieve tokens after a plain forward().
        self._cached_patch_features = patch_features
        # 3) Pool: masked mean / masked attention pooling when a validity mask
        # is present, otherwise the trunk's own head.
        if (
            patch_valid_mask is not None
            and getattr(self.trunk, "global_pool", "") == "avg"
        ):
            # Mean over valid patches only; clamp avoids division by zero for
            # fully-masked rows.
            pooled = patch_features.sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
            pooled = (
                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
            )
        elif (
            patch_valid_mask is not None
            and getattr(self.trunk, "attn_pool", None) is not None
        ):
            # Additive 0 / -inf mask shaped (B, 1, 1, L) for the attn pooler.
            attn_mask = torch.zeros(
                patch_valid_mask.shape,
                dtype=patch_features.dtype,
                device=patch_features.device,
            )
            attn_mask.masked_fill_(~patch_valid_mask.bool(), float("-inf"))
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
            pooled = self.trunk.attn_pool(patch_features, attn_mask=attn_mask)
            pooled = (
                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
            )
        else:
            pooled = self.trunk.forward_head(patch_features)
        # 4) Project into the shared embedding space (pool/proj head built in __init__).
        pooled = self.head(pooled)
        if self.output_tokens:
            return pooled, patch_features
        return pooled
|
raon_vision_encoder/tokenizer.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
import html
|
| 4 |
+
import os
|
| 5 |
+
import string
|
| 6 |
+
from typing import List, Optional, Union
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import ftfy
|
| 11 |
+
except ImportError:
|
| 12 |
+
ftfy = None
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
DEFAULT_CONTEXT_LENGTH = 77
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def basic_clean(text):
    """Repair encoding/HTML artifacts in raw text.

    Runs ftfy's mojibake fixer when the optional ``ftfy`` package is
    installed, then unescapes HTML entities and strips surrounding
    whitespace.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string.
    """
    # Original code had a dead bare-expression statement (`else: text`);
    # the else branch is simply a no-op, so it is dropped.
    if ftfy is not None:
        text = ftfy.fix_text(text)
    # Double-unescape is intentional: web-scraped captions are frequently
    # escaped twice (e.g. "&amp;amp;" -> "&amp;" -> "&").
    text = html.unescape(html.unescape(text))
    return text.strip()
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    collapsed = " ".join(text.split())
    return collapsed.strip()
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _clean_canonicalize(x):
    """Basic-clean then canonicalize (lowercase, punctuation removed)."""
    fixed = basic_clean(x)
    return canonicalize_text(fixed)
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _clean_lower(x):
    """Basic-clean, collapse whitespace, then lowercase."""
    cleaned = whitespace_clean(basic_clean(x))
    return cleaned.lower()
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _clean_whitespace(x):
    """Basic-clean then collapse whitespace (case preserved)."""
    fixed = basic_clean(x)
    return whitespace_clean(fixed)
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_clean_fn(type: str):
    """Look up a text-cleaning function by name.

    Args:
        type: One of "canonicalize", "lower", or "whitespace".

    Returns:
        The corresponding cleaning callable.

    Raises:
        ValueError: If ``type`` is not a known cleaning mode.
    """
    # Dispatch table instead of an if/elif chain; raise ValueError rather
    # than `assert False`, which is silently stripped under `python -O`.
    clean_fns = {
        "canonicalize": _clean_canonicalize,
        "lower": _clean_lower,
        "whitespace": _clean_whitespace,
    }
    try:
        return clean_fns[type]
    except KeyError:
        raise ValueError(f"Invalid clean function ({type}).") from None
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def canonicalize_text(
    text,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    Underscores become spaces. When ``keep_punctuation_exact_string`` is
    given, that exact substring is preserved while punctuation elsewhere is
    stripped. Whitespace runs are collapsed to single spaces.
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        stripped = [piece.translate(trans_punctuation) for piece in pieces]
        text = keep_punctuation_exact_string.join(stripped)
    else:
        text = text.translate(trans_punctuation)
    # Lowercasing commutes with whitespace collapsing, so order is free here.
    return " ".join(text.lower().split()).strip()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes"""

    def __init__(
        self,
        tokenizer_name: str,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "whitespace",
        strip_sep_token: bool = False,
        language: Optional[str] = None,
        cache_dir: Optional[str] = None,
        tokenizer_mode: Optional[str] = None,
        **kwargs,
    ):
        """Load an AutoTokenizer and configure cleaning / packing behavior.

        Args:
            tokenizer_name: HF hub name or local path of the tokenizer.
            context_length: Default output sequence length.
            clean: Cleaning mode, passed to get_clean_fn().
            strip_sep_token: If True, replace [SEP] ids with 0 after encoding.
            language: Optional source language for multilingual tokenizers.
            cache_dir: HF cache directory.
            tokenizer_mode: "clips" selects the custom BOS/EOS + trailing-CLS
                packing in _clips_tokenize; anything else (or None) uses plain
                HF padding/truncation.
            **kwargs: Forwarded to AutoTokenizer.from_pretrained.
        """
        self.tokenizer_mode = tokenizer_mode or ""
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # Deferred import keeps `transformers` optional at module import time.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name, cache_dir=cache_dir, **kwargs
        )

        # Multilingual tokenizers (e.g. NLLB-style) expose a source-language
        # setter; keep a reference only when it exists so set_language() can
        # warn gracefully otherwise.
        set_lang_fn = getattr(self.tokenizer, "set_src_lang_special_tokens", None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Persist the wrapped tokenizer to `dest` (directory path)."""
        self.tokenizer.save_pretrained(dest)

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.Tensor:
        """Tokenize `texts` into a (batch, context_length) LongTensor of ids."""
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, (
            "Please set a valid context length in class init or call."
        )

        # Normalize each text with the configured clean function first.
        texts = [self.clean_fn(text) for text in texts]

        if self.tokenizer_mode == "clips":
            return self._clips_tokenize(texts, context_length)
        else:
            output = self.tokenizer(
                texts,
                return_tensors="pt",
                max_length=context_length,
                padding="max_length",
                truncation=True,
            )
            input_ids = output.input_ids

            if self.strip_sep_token:
                # Replace [SEP] ids with 0 in place (rather than removing
                # them) so positions stay aligned across the batch.
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )

            return input_ids

    def set_language(self, src_lang):
        """Set the tokenizer's source language, warning if unsupported."""
        if hasattr(self, "set_lang_fn"):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn("Cannot set language for the tokenizer.")

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """CLIPS-style packing: [BOS] tokens [EOS] ... pad ... [CLS] in the last slot."""
        # Encode without special tokens or padding; specials are added manually.
        encoded_outputs = self.tokenizer(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None,
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            # Reserve 3 slots: BOS, EOS, and the trailing CLS added later.
            tokens = tokens[: context_length - 3]
            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )
            encoded.append(tokens)

        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, : len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
        self,
        tokens: List[int],
        max_length: int,
        pad_token_id: int = 0,
        cls_token_id: int = 101,
    ) -> List[int]:
        """Pad/truncate `tokens` to max_length - 1, then append CLS as the final id."""
        if len(tokens) > max_length - 1:
            tokens = tokens[: max_length - 1]
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
        tokens = tokens + [cls_token_id]
        return tokens
|
raon_vision_encoder/transform.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_image_size_for_max_num_patches(
    image_height, image_width, patch_size, max_num_patches
):
    """Find target image size preserving aspect ratio within patch budget.

    Binary-searches a resize scale so that
    ceil(h*scale/ps) * ceil(w*scale/ps) <= max_num_patches.

    Args:
        image_height: Original image height.
        image_width: Original image width.
        patch_size: Patch size (int).
        max_num_patches: Maximum number of patches allowed.

    Returns:
        (target_h, target_w) both multiples of patch_size.
    """

    def _snap(side, scale):
        # Scale a side length and round up to the next patch multiple,
        # never dropping below a single patch.
        return max(patch_size, int(math.ceil(side * scale / patch_size) * patch_size))

    lo, hi = 1e-6, 100.0
    tolerance = 1e-5
    while (hi - lo) >= tolerance:
        mid = (lo + hi) / 2
        patches = (_snap(image_height, mid) // patch_size) * (
            _snap(image_width, mid) // patch_size
        )
        if patches <= max_num_patches:
            lo = mid
        else:
            hi = mid
    return _snap(image_height, lo), _snap(image_width, lo)
|
raon_vision_encoder/transformer.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import math
|
| 5 |
+
from typing import Callable, Optional, Type, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.utils.checkpoint import checkpoint
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, casting back afterwards.

    Handles fp16/bf16 inputs safely: reduced-precision reductions inside
    layer norm lose accuracy, so the computation runs in float32.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normed.to(input_dtype)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LayerNorm(nn.LayerNorm):
    """LayerNorm that casts its output back to the input dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(input_dtype)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``.

    NOTE: slower than nn.GELU or nn.SiLU and uses more GPU memory; kept for
    compatibility with weights trained using this activation.
    """

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling (``gamma``), optionally applied in place."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Attention(nn.Module):
    """Multi-head self-attention with optional scaled-cosine attention, QK
    normalization, per-head output scaling, and an inner norm before out_proj.

    Uses F.scaled_dot_product_attention when available (default path);
    scaled-cosine and the non-SDPA fallback compute attention manually.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        scaled_cosine: bool = False,
        scale_heads: bool = False,
        inner_norm: bool = False,
        logit_scale_max: float = math.log(1.0 / 0.01),
        norm_layer: Type[nn.Module] = LayerNormFp32,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        """
        Args:
            dim: Embedding dimension.
            num_heads: Number of attention heads; must divide ``dim``.
            qkv_bias: Add bias to the fused QKV projection.
            qk_norm: Apply per-head LayerNorm to queries and keys.
            scaled_cosine: Use cosine-similarity attention with a learned,
                clamped per-head logit scale (mutually exclusive with qk_norm).
            scale_heads: Learn a per-head scale applied to attention output.
            inner_norm: Apply a norm between attention output and out_proj.
            logit_scale_max: Clamp ceiling for the scaled-cosine logit scale.
            norm_layer: Norm class used for qk_norm / inner_norm.
            attn_drop: Dropout on attention weights.
            proj_drop: Dropout after the output projection.
        """
        super().__init__()
        assert not (scaled_cosine and qk_norm), (
            "Cannot activate both scaled cosine and QK normalization"
        )
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.logit_scale_max = logit_scale_max
        self.use_fsdpa = hasattr(nn.functional, "scaled_dot_product_attention")

        # Fused QKV projection parameters, scaled at init like nn.MultiheadAttention.
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if qk_norm:
            self.ln_q = norm_layer(self.head_dim)
            self.ln_k = norm_layer(self.head_dim)
        else:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()

        if self.scaled_cosine:
            # Learned per-head temperature, stored in log space.
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1)))
            )
        else:
            self.logit_scale = None

        self.attn_drop = nn.Dropout(attn_drop)

        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None

        if inner_norm:
            self.ln_inner = norm_layer(dim)
        else:
            self.ln_inner = nn.Identity()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply attention to ``x`` of shape (N, L, C); returns (N, L, C)."""
        N, L, C = x.shape
        # Fused QKV projection; each of q/k/v becomes (N, num_heads, L, head_dim).
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = q.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        k = k.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        v = v.reshape(N, L, self.num_heads, -1).transpose(1, 2)

        if attn_mask is not None:
            if attn_mask.ndim == 3:
                attn_mask = attn_mask.reshape(N, self.num_heads, L, L)
            if attn_mask.dtype == torch.bool:
                # Convert boolean mask (True = masked out) to additive float mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            else:
                attn_mask = attn_mask.to(dtype=q.dtype)

        if self.logit_scale is not None:
            # Scaled-cosine attention. BUGFIX: q/k/v are 4-D
            # (N, num_heads, L, head_dim) here; torch.bmm only accepts 3-D
            # inputs, so torch.matmul (batched over leading dims) is required.
            attn = torch.matmul(
                F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
            )
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn * logit_scale
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.matmul(attn, v)
        else:
            q = self.ln_q(q)
            k = self.ln_k(k)
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                )
            else:
                # Manual fallback; same 4-D shapes as above, hence matmul.
                q = q * self.scale
                attn = torch.matmul(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn += attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.matmul(attn, v)

        if self.head_scale is not None:
            # (num_heads, 1, 1) broadcasts over (N, num_heads, L, head_dim).
            x = x * self.head_scale
        x = x.transpose(1, 2).reshape(N, L, C)
        x = self.ln_inner(x)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention + MLP, each wrapped in
    a residual connection with optional LayerScale. Cross-attention variants
    get a dedicated norm (``ln_1_kv``) for the key/value stream."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        is_cross_attention: bool = False,
        batch_first: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        if ls_init_value is not None:
            self.ls_1 = LayerScale(d_model, ls_init_value)
        else:
            self.ls_1 = nn.Identity()
        if is_cross_attention:
            # Separate norm for the external key/value sequence.
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        hidden_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, hidden_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden_width, d_model)),
                ]
            )
        )
        if ls_init_value is not None:
            self.ls_2 = LayerScale(d_model, ls_init_value)
        else:
            self.ls_2 = nn.Identity()

    def get_weight_dtype(self) -> torch.dtype:
        """Dtype of the MLP input projection (original dtype if int8-quantized)."""
        fc = self.mlp.c_fc
        if hasattr(fc, "int8_original_dtype"):
            return fc.int8_original_dtype
        return fc.weight.dtype

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Run MultiheadAttention; self-attention when k_x/v_x are omitted."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x

        if attn_mask is not None:
            attn_mask = attn_mask.to(q_x.dtype)
        attn_out, _ = self.attn(
            q_x,
            k_x,
            v_x,
            need_weights=False,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )
        return attn_out

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        # Only cross-attention blocks own ln_1_kv; otherwise k/v fall through
        # as None and attention() defaults them to the normed query stream.
        has_kv_norm = hasattr(self, "ln_1_kv")
        normed_k = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        normed_v = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None
        x = q_x + self.ls_1(
            self.attention(
                q_x=self.ln_1(q_x),
                k_x=normed_k,
                v_x=normed_v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
            )
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class CustomResidualAttentionBlock(nn.Module):
    """Transformer block built on the project's custom ``Attention`` module.

    Supports the extra attention variants that ``nn.MultiheadAttention``
    cannot express: QK-norm, scaled-cosine attention, per-head scaling, an
    inner attention norm, a post-attention norm (``ln_attn``) and an extra
    norm inside the MLP (``scale_fc``).
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        qk_norm: bool = False,
        scale_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
    ):
        """Build the block.

        Args:
            d_model: embedding width.
            n_head: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``d_model``.
            ls_init_value: initial LayerScale value; ``None`` disables it.
            act_layer: activation constructor for the MLP.
            norm_layer: normalization constructor.
            qk_norm: normalize queries/keys inside Attention.
            scale_cosine_attn: use scaled-cosine attention.
            scale_heads: learn a per-head output scale.
            scale_attn_inner: apply a norm inside Attention (``inner_norm``).
            scale_attn: apply a norm after the attention output.
            scale_fc: apply a norm between the MLP's activation and projection.
            batch_first: must be True; the custom Attention assumes (B, L, C).
        """
        super().__init__()
        assert batch_first, "batch_first must be True for CustomResidualAttentionBlock"

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            qk_norm=qk_norm,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            inner_norm=scale_attn_inner,
            norm_layer=norm_layer,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

    def get_weight_dtype(self) -> torch.dtype:
        """Return the compute dtype of the block's weights.

        Prefers ``int8_original_dtype`` when present (set by int8/quantized
        linear replacements) so callers cast to the pre-quantization dtype.
        """
        if hasattr(self.mlp.c_fc, "int8_original_dtype"):
            return self.mlp.c_fc.int8_original_dtype
        return self.mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Pre-norm self-attention + MLP, each residual and optionally LayerScaled."""
        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class Transformer(nn.Module):
    """A stack of residual attention blocks.

    Chooses between the default ``ResidualAttentionBlock`` (backed by
    ``nn.MultiheadAttention``) and ``CustomResidualAttentionBlock`` (backed by
    the project's ``Attention``) either explicitly via ``block_type`` or
    implicitly when any custom attention option is enabled.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        batch_first: bool = True,
        block_type: Optional[str] = None,
        qk_norm: bool = False,
        scaled_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
    ):
        """Build the block stack.

        Args:
            width: embedding width of every block.
            layers: number of blocks.
            heads: attention heads per block.
            mlp_ratio: MLP hidden width multiplier.
            ls_init_value: LayerScale init value; ``None`` disables LayerScale.
            act_layer: activation constructor.
            norm_layer: normalization constructor.
            batch_first: if False, inputs are (L, B, C) and are transposed
                internally.
            block_type: "custom", "default", or ``None`` to auto-select.
            qk_norm / scaled_cosine_attn / scale_heads / scale_attn_inner /
            scale_attn / scale_fc: custom-attention options; any of them set
                forces the custom block type when ``block_type`` is ``None``.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        # Toggled externally to trade compute for activation memory.
        self.grad_checkpointing = False

        # Auto-select the block implementation: any custom attention option
        # requires CustomResidualAttentionBlock.
        if block_type is None:
            if any(
                [
                    qk_norm,
                    scaled_cosine_attn,
                    scale_heads,
                    scale_attn_inner,
                    scale_attn,
                    scale_fc,
                ]
            ):
                block_type = "custom"
            else:
                block_type = "default"

        if block_type == "custom":
            self.resblocks = nn.ModuleList(
                [
                    CustomResidualAttentionBlock(
                        width,
                        heads,
                        mlp_ratio,
                        ls_init_value=ls_init_value,
                        act_layer=act_layer,
                        norm_layer=norm_layer,
                        qk_norm=qk_norm,
                        scale_cosine_attn=scaled_cosine_attn,
                        scale_heads=scale_heads,
                        scale_attn_inner=scale_attn_inner,
                        scale_attn=scale_attn,
                        scale_fc=scale_fc,
                        batch_first=batch_first,
                    )
                    for _ in range(layers)
                ]
            )
        else:
            self.resblocks = nn.ModuleList(
                [
                    ResidualAttentionBlock(
                        width,
                        heads,
                        mlp_ratio,
                        ls_init_value=ls_init_value,
                        act_layer=act_layer,
                        norm_layer=norm_layer,
                        batch_first=batch_first,
                    )
                    for _ in range(layers)
                ]
            )

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype inputs should be cast to; delegated to the first block."""
        return self.resblocks[0].get_weight_dtype()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run the block stack over ``x``.

        Args:
            x: input features, (B, L, C) if ``batch_first`` else (L, B, C).
            attn_mask: optional additive attention mask forwarded to every block.
        """
        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()  # -> (B, L, C) for the blocks

        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # The two block types have different forward signatures, so the
                # positional argument lists passed to checkpoint() must differ.
                # (Previously four positional args were always passed, which
                # raised TypeError for CustomResidualAttentionBlock.)
                if isinstance(r, ResidualAttentionBlock):
                    # forward(q_x, k_x, v_x, attn_mask, key_padding_mask)
                    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
                else:
                    # CustomResidualAttentionBlock.forward(x, attn_mask)
                    x = checkpoint(r, x, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)

        if not self.batch_first:
            x = x.transpose(0, 1)  # restore (L, B, C)
        return x
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _expand_token(token, batch_size: int):
|
| 425 |
+
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def text_global_pool(
    x: torch.Tensor,
    text: Optional[torch.Tensor] = None,
    pool_type: str = "argmax",
    eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Pool a (B, L, C) sequence of token features into one feature per sample.

    Args:
        x: token features, shape (B, L, C).
        text: token ids, shape (B, L); required for "argmax" and "eos".
        pool_type: "first" / "last" take a fixed position; "argmax" takes the
            position of the max token id (CLIP's EOT convention); "eos" takes
            the first occurrence of ``eos_token_id``; anything else returns
            ``x`` unchanged.
        eos_token_id: id of the EOS token; required for "eos".

    Returns:
        Pooled features (B, C), or ``x`` itself for unrecognized pool types.
    """
    if pool_type == "first":
        return x[:, 0]
    if pool_type == "last":
        return x[:, -1]

    if pool_type == "argmax":
        assert text is not None
        rows = torch.arange(x.shape[0], device=x.device)
        return x[rows, text.argmax(dim=-1)]

    if pool_type == "eos":
        assert text is not None
        assert eos_token_id is not None
        rows = torch.arange(x.shape[0], device=x.device)
        # argmax over the boolean match picks the FIRST eos position per row.
        eos_pos = (text == eos_token_id).int().argmax(dim=-1)
        return x[rows, eos_pos]

    # No pooling requested: hand back the full token sequence.
    return x
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
class TextTransformer(nn.Module):
    """CLIP-style text tower: token + positional embeddings, a Transformer
    stack, pooling, and an optional output projection.

    Optionally appends a learned CLS token at the END of the sequence
    (``embed_cls``), can drop the causal mask (``no_causal_mask``), and can
    mask out padding tokens (``use_pad_mask``).
    """

    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: Optional[int] = 512,
        embed_cls: bool = False,
        no_causal_mask: bool = False,
        use_pad_mask: bool = False,
        correct_cls_mask: bool = False,
        pad_id: int = 0,
        eos_id: int = 2,
        pool_type: str = "argmax",
        proj_type: str = "linear",
        proj_bias: bool = False,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        output_tokens: bool = False,
        block_type: Optional[str] = None,
        qk_norm: bool = False,
        scaled_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
    ):
        """Build the text tower.

        Args:
            context_length: max token sequence length.
            vocab_size: tokenizer vocabulary size.
            width / heads / layers / mlp_ratio / ls_init_value: transformer
                hyperparameters, forwarded to ``Transformer``.
            output_dim: projection output width; falsy disables the projection.
            embed_cls: append a learned CLS token (grows positions by one).
            no_causal_mask: disable the causal attention mask.
            use_pad_mask: mask padding tokens; only honored together with
                ``no_causal_mask`` (see ``self.use_pad_mask`` below).
            correct_cls_mask: place the CLS slot of the pad mask at the end of
                the sequence (matching where the CLS token is concatenated);
                when False, it is prepended — presumably the legacy layout kept
                for checkpoint compatibility. TODO confirm against released
                weights.
            pad_id / eos_id: token ids for padding and end-of-sequence.
            pool_type: one of "first", "last", "argmax", "eos", "none".
            proj_type / proj_bias: "none" disables the projection; otherwise a
                Linear (with bias) or a bare weight matrix (no bias).
            act_layer / norm_layer: layer constructors for the transformer.
            output_tokens: if True, forward() also returns per-token features.
            block_type / qk_norm / scaled_cosine_attn / scale_heads /
            scale_attn_inner / scale_attn / scale_fc: forwarded to
                ``Transformer`` for custom-attention variants.
        """
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "eos", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id
        self.eos_id = eos_id
        self.pool_type = pool_type
        # Pad masking is only applied when the causal mask is off; with a
        # causal mask, the two are not combined via this flag.
        self.use_pad_mask = use_pad_mask and no_causal_mask
        self.correct_cls_mask = correct_cls_mask

        self.token_embedding = nn.Embedding(vocab_size, width)
        if embed_cls:
            self.cls_emb = nn.Parameter(torch.empty(width))
            self.num_pos += 1  # one extra position for the appended CLS token
        else:
            self.cls_emb = None
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=block_type,
            qk_norm=qk_norm,
            scaled_cosine_attn=scaled_cosine_attn,
            scale_heads=scale_heads,
            scale_attn_inner=scale_attn_inner,
            scale_attn=scale_attn,
            scale_fc=scale_fc,
        )
        self.ln_final = norm_layer(width)

        if no_causal_mask:
            self.attn_mask = None
        else:
            # Non-persistent: the mask is cheap to rebuild and excluded from
            # state_dict.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_type == "none" or not output_dim:
            self.text_projection = None
        else:
            if proj_bias:
                self.text_projection = nn.Linear(width, output_dim)
            else:
                self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        self.init_parameters()

    def init_parameters(self):
        """Initialize embeddings, per-block attention/MLP weights, and the
        output projection with CLIP's depth-scaled normal init."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if self.cls_emb is not None:
            nn.init.normal_(self.cls_emb, std=0.01)

        # Depth-scaled std for residual-path projections (CLIP convention).
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            # NOTE(review): assumes each block's attn exposes in_proj_weight /
            # out_proj (true for nn.MultiheadAttention; confirm for the custom
            # Attention block type).
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                nn.init.normal_(
                    self.text_projection.weight, std=self.transformer.width**-0.5
                )
                if self.text_projection.bias is not None:
                    nn.init.zeros_(self.text_projection.bias)
            else:
                nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_causal_mask(self):
        """Return an additive causal mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def _build_additive_mask(self, text, seq_len, dtype):
        """Build an additive key-padding mask of shape (B * heads, L, L).

        Positions whose token equals ``pad_id`` contribute -inf as attention
        keys. When a CLS token is used, its (always-valid) slot is appended at
        the end if ``correct_cls_mask`` else prepended (legacy layout; see
        __init__ docstring).
        """
        valid = text != self.pad_id
        if self.cls_emb is not None:
            cls_valid = valid.new_ones(valid.size(0), 1)
            valid = torch.cat(
                [valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1
            )
        # Broadcast key validity over every query position.
        key_mask = valid.unsqueeze(1).expand(-1, seq_len, -1)
        additive = torch.zeros_like(key_mask, dtype=dtype)
        additive.masked_fill_(~key_mask, float("-inf"))
        # MultiheadAttention expects a 3-D mask shaped (B * num_heads, L, S).
        additive = additive.repeat_interleave(self.heads, 0)
        return additive

    def _embeds(self, text):
        """Embed token ids and assemble the attention mask.

        Returns:
            (x, attn_mask): embedded inputs of shape (B, L[, +1 with CLS], C)
            with positions added, and the combined additive attention mask
            (or None).
        """
        cast_dtype = self.transformer.get_cast_dtype()
        B, seq_len = text.shape
        x = self.token_embedding(text).to(cast_dtype)
        if self.cls_emb is not None:
            # CLS token is appended at the END of the sequence.
            x = torch.cat([x, _expand_token(self.cls_emb, x.size(0))], 1)
            seq_len += 1
        attn_mask = self.attn_mask
        if self.use_pad_mask or self.cls_emb is not None:
            add_mask = self._build_additive_mask(text, seq_len, x.dtype)
            if attn_mask is not None:
                # Combine causal (1, L, L broadcast) and padding (B*H, L, L) masks.
                attn_mask = attn_mask[:seq_len, :seq_len].unsqueeze(0) + add_mask
            else:
                attn_mask = add_mask
        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
        return x, attn_mask

    def forward(self, text):
        """Encode token ids to pooled text features.

        Args:
            text: token ids, shape (B, context_length).

        Returns:
            Pooled (and optionally projected) features (B, output_dim or
            width); if ``output_tokens`` is set, a (pooled, tokens) pair.
        """
        x, attn_mask = self._embeds(text)
        x = self.transformer(x, attn_mask=attn_mask)
        if self.cls_emb is not None:
            # CLS sits at the last position; note ln_final is applied only to
            # the pooled CLS here, while the returned tokens are un-normalized.
            pooled = text_global_pool(x, pool_type="last")
            pooled = self.ln_final(pooled)
            tokens = x[:, :-1]
        else:
            x = self.ln_final(x)
            pooled = text_global_pool(
                x,
                text,
                pool_type=self.pool_type,
                eos_token_id=getattr(self, "eos_id", None),
            )
            tokens = x
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        if self.output_tokens:
            return pooled, tokens
        return pooled
|
raon_vision_encoder/utils.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
|
| 2 |
+
|
| 3 |
+
import collections.abc
|
| 4 |
+
from itertools import repeat
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _ntuple(n):
|
| 8 |
+
def parse(x):
|
| 9 |
+
if isinstance(x, collections.abc.Iterable):
|
| 10 |
+
return x
|
| 11 |
+
return tuple(repeat(x, n))
|
| 12 |
+
|
| 13 |
+
return parse
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
to_2tuple = _ntuple(2)
|