Image-Text-to-Text
Transformers
Safetensors
youtu_vl
text-generation
conversational
custom_code
Yinsongliu committed on
Commit
c13c3aa
·
1 Parent(s): fdcd432

Upload model with LFS assets

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,200 @@
1
+ Tencent is pleased to support the community by making Youtu-Parsing available.
2
+
3
+ Copyright (C) 2026 Tencent. All rights reserved. Youtu-Parsing IS NOT INTENDED FOR USE WITHIN THE EUROPEAN UNION.
4
+
5
+ Youtu-Parsing is licensed under the License Terms of Youtu-Parsing except for the third-party components listed below, which is licensed under different terms. Youtu-Parsing does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
6
+
7
+ For avoidance of doubts, Youtu-Parsing refers to the inference enabling code, parameters and weights made publicly available by Tencent in accordance with the License Terms of Youtu-Parsing in this repository.
8
+
9
+ Terms of the License Terms of Youtu-Parsing:
10
+ --------------------------------------------------------------------
11
+
12
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
13
+ 0. Additional Territorial Limitation
14
+
15
+ *Youtu-Parsing IS NOT INTENDED FOR USE WITHIN THE EUROPEAN UNION.*
16
+ IN THE EVENT OF ANY CONFLICT, THIS CLAUSE SHALL PREVAIL.
17
+
18
+ 1. Definitions.
19
+
20
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
21
+
22
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
23
+
24
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
25
+
26
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
31
+
32
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
33
+
34
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
35
+
36
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
37
+
38
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
39
+
40
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
41
+
42
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
43
+
44
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
45
+
46
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
47
+
48
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
49
+
50
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
51
+
52
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
53
+
54
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
55
+
56
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
57
+
58
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
59
+
60
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
61
+
62
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
63
+
64
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
65
+
66
+ END OF TERMS AND CONDITIONS
67
+
68
+
69
+ The Code of this project is built on and with the aid of the following open source projects. Credits are given to these projects.
70
+
71
+ The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications
72
+ are Copyright(C)Tencent.
73
+
74
+ Open Source Software Licensed under the Apache-2.0:
75
+ --------------------------------------------------------------------
76
+ 1. transformers
77
+ Copyright 2018- The Hugging Face team.
78
+
79
+ 2. Opencv
80
+ Copyright (c) 2025 opencv original author and authors
81
+
82
+ 3. vLLM
83
+ Copyright (c) 2025 vllm original author and authors
84
+ Terms of the Apache-2.0:
85
+ --------------------------------------------------------------------
86
+ Apache License
87
+ Version 2.0, January 2004
88
+ http://www.apache.org/licenses/
89
+
90
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
91
+
92
+ 1. Definitions.
93
+
94
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
95
+
96
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
97
+
98
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
99
+
100
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
101
+
102
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
103
+
104
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
105
+
106
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
107
+
108
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
109
+
110
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
111
+
112
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
113
+
114
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
115
+
116
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
117
+
118
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
119
+
120
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
121
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
122
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
123
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
124
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
125
+
126
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
127
+
128
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
129
+
130
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
131
+
132
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
133
+
134
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
135
+
136
+ END OF TERMS AND CONDITIONS
137
+
138
+
139
+
140
+
141
+
142
+ Open Source Software Licensed under the BSD-3-Clause:
143
+ --------------------------------------------------------------------
144
+ 1. numpy
145
+ Copyright (c) 2005-2024, NumPy Developers.
146
+
147
+ 2. scipy
148
+ Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers.
149
+
150
+ 3. torch
151
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke); Copyright (c) 2014- Facebook, Inc (Soumith Chintala); Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert); Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu); Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu); Copyright (c) 2011-2013 NYU (Clement Farabet); Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston); Copyright (c) 2006 Idiap Research Institute (Samy Bengio); Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz); Copyright (c) 2016-present, Facebook Inc. All rights reserved.; Copyright (c) 2016 Facebook Inc.; Copyright (c) 2015 Google Inc.; Copyright (c) 2015 Yangqing Jia; Copyright 2019-2020 Kakao Brain; Copyright (c) 2022 Cruise LLC.; Copyright (c) 2024 Tri Dao.; Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates; Copyright(c) 2013, 2014, 2015, the respective contributors; Copyright(c) 2015, 2016 the respective contributors
152
+
153
+ 4. torchvision
154
+ Copyright (c) Soumith Chintala 2016,
155
+ Terms of the BSD-3-Clause:
156
+ --------------------------------------------------------------------
157
+ BSD 3-Clause License
158
+
159
+ Redistribution and use in source and binary forms, with or without
160
+ modification, are permitted provided that the following conditions are met:
161
+
162
+ 1. Redistributions of source code must retain the above copyright notice, this
163
+ list of conditions and the following disclaimer.
164
+
165
+ 2. Redistributions in binary form must reproduce the above copyright notice,
166
+ this list of conditions and the following disclaimer in the documentation
167
+ and/or other materials provided with the distribution.
168
+
169
+ 3. Neither the name of the copyright holder nor the names of its
170
+ contributors may be used to endorse or promote products derived from
171
+ this software without specific prior written permission.
172
+
173
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
174
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
175
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
176
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
177
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
178
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
179
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
180
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
181
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
182
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
183
+
184
+
185
+
186
+
187
+
188
+ Open Source Software Licensed under the MIT-CMU:
189
+ --------------------------------------------------------------------
190
+ 1. Pillow
191
+ Copyright © 1997-2011 by Secret Labs AB; Copyright © 1995-2011 by Fredrik Lundh and contributors; Copyright © 2010 by Jeffrey A. Clark and contributors
192
+ Terms of the MIT-CMU:
193
+ --------------------------------------------------------------------
194
+ By obtaining, using, and/or copying this software and/or its associated documentation, you agree that you have read, understood, and will comply with the following terms and conditions:
195
+ Permission to use, copy, modify, and distribute this software and its associated documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appears in all copies, and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of the copyright holder not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission.
196
+
197
+ THE COPYRIGHT HOLDER DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM THE LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
198
+
199
+ ==================================================
200
+ End of the Attribution Notice of this project.
README.md ADDED
@@ -0,0 +1,183 @@
1
+ ---
2
+ license: other
3
+ license_name: youtu-parsing
4
+ license_link: https://huggingface.co/tencent/Youtu-Parsing/blob/main/LICENSE.txt
5
+ pipeline_tag: image-text-to-text
6
+ base_model:
7
+ - tencent/Youtu-LLM-2B
8
+ base_model_relation: finetune
9
+ ---
10
+ <div align="center">
11
+
12
+ # <img src="assets/youtu-parsing-logo.png" alt="Youtu-Parsing Logo" height="100px">
13
+
14
+ [📃 License](https://huggingface.co/tencent/Youtu-Parsing/blob/main/LICENSE.txt) • [💻 Code](https://github.com/TencentCloudADP/youtu-parsing) • [📊 Benchmarks](#benchmarks) • [🚀 Getting Started](#quickstart)
15
+
16
+ </div>
17
+
18
+ <div align="center">
19
+ <img src="./assets/static_v40.png" width="800"/>
20
+ </div>
21
+
22
+
23
+ ## 🎯 Introduction
24
+
25
+ **Youtu-Parsing** is a specialized document parsing model built on the open-source Youtu-LLM 2B foundation model. By extending the base model with a prompt-guided framework and a NaViT-style dynamic visual encoder, Youtu-Parsing delivers enhanced parsing of diverse document elements, including text, tables, formulas, and charts. The model incorporates an efficient parallel decoding mechanism that significantly accelerates inference, making it practical for real-world document analysis applications. We share Youtu-Parsing with the community to facilitate research and development in document understanding.
26
+
27
+
28
+ ## ✨ Key Features
29
+ ### 📄 **Document Structure Preservation**
30
+ - **Text Localization**: Accurately detects and localizes text regions with pixel-level precision, ensuring no content is missed or misplaced across diverse document layouts.
31
+ - **Reading Order Restoration**: Intelligently reconstructs the logical reading sequence of document content, maintaining proper flow across columns, sections, and pages for coherent understanding.
32
+
33
+ ### 📊 **Advanced Content Recognition**
34
+ - **Text Recognition**: Provides precise text recognition across diverse scenarios.
35
+ - **Formula Recognition**: Automatically converts mathematical expressions to LaTeX format.
36
+ - **Table Recognition**: Automatically detects tables and converts them to HTML format.
37
+ - **Chart Recognition**: Converts charts to Markdown tables, and converts mind maps and flowcharts to Mermaid format.
38
+
39
+ ### ⚡ **High-Performance Inference**
40
+ - **Token Parallelism**: Enables simultaneous inference of multiple tokens for accelerated processing, achieving a 5-11x speedup.
41
+ - **Query Parallelism**: Batches multiple queries together to maximize the benefit of Token Parallelism, providing an additional 2x speedup on top of the Token Parallelism acceleration.
42
+
43
+ <div align="center">
44
+ <img src="./assets/parallel_decoder.png" width="800"/>
45
+ </div>
46
+
47
+ <a id="benchmarks"></a>
48
+
49
+ ## 📊 Performance
50
+ ### 1. OmniDocBench v1.5
51
+
52
+ <div align="center">
53
+ <img src="./assets/OminiDocBench_v1.5.png" width="800"/>
54
+ </div>
55
+
56
+ ### 2. olmOCR
57
+ <div align="center">
58
+ <img src="./assets/olmOCR.png" width="800"/>
59
+ </div>
60
+
61
+
62
+ <a id="quickstart"></a>
63
+
64
+ ## 🚀 Quick Start
65
+ ### Install packages
66
+ ```bash
67
+ conda create -n youtu_parsing python=3.10
68
+ conda activate youtu_parsing
69
+ pip install git+https://github.com/TencentCloudADP/youtu-parsing.git#subdirectory=youtu_hf_parser
70
+
71
+ # Install flash-attn (FlashAttention-2)
72
+ # For CUDA 12.x + PyTorch 2.6 + Python 3.10 + Linux x86_64:
73
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
74
+
75
+ # Alternative: Install from PyPI
76
+ pip install flash-attn==2.7.0
77
+ ```
78
+
79
+ ### Usage with transformers
80
+ ```python
81
+ from youtu_hf_parser import YoutuOCRParserHF
82
+
83
+ # Initialize the parser
84
+ parser = YoutuOCRParserHF(
85
+ model_path=model_path,
86
+ enable_angle_correct=True, # Set to False to disable angle correction
87
+ angle_correct_model_path=angle_correct_model_path
88
+ )
89
+
90
+ # Parse an image
91
+ parser.parse_file(input_path=image_path, output_dir=output_dir)
92
+ ```
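In the snippet above, `model_path`, `angle_correct_model_path`, `image_path`, and `output_dir` are placeholders for paths you supply. Because `config.json` also registers the custom classes via `auto_map`, the underlying model and processor can be loaded directly with `transformers` as well; the following is a minimal sketch (not the official entry point), assuming `trust_remote_code=True` and the repo id `tencent/Youtu-Parsing`:

```python
# Hedged sketch: load the raw vision-language model and processor directly.
# The YoutuOCRParserHF wrapper above remains the recommended interface.
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "tencent/Youtu-Parsing"  # or a local checkpoint directory
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype="bfloat16",  # weights are shipped in bfloat16 (see config.json)
)
```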
93
+
94
+ ## 🎨 Visualization
95
+ ### Text Recognition
96
+ <div align="center">
97
+ <img src="./assets/handwriting.png" width="800"/>
98
+ </div>
99
+
100
+ <div align="center">
101
+ <img src="./assets/art.png" width="800"/>
102
+ </div>
103
+
104
+ <div align="center">
105
+ <img src="./assets/printed.png" width="800"/>
106
+ </div>
107
+
108
+ ### Formula Recognition
109
+ <div align="center">
110
+ <img src="./assets/formula.png" width="800"/>
111
+ </div>
112
+
113
+ ### Table Recognition
114
+ <div align="center">
115
+ <img src="./assets/wired.png" width="800"/>
116
+ </div>
117
+
118
+ <div align="center">
119
+ <img src="./assets/wireless.png" width="800"/>
120
+ </div>
121
+
122
+ ### Chart Recognition
123
+ <div align="center">
124
+ <img src="./assets/chart.png" width="800"/>
125
+ </div>
126
+
127
+ <div align="center">
128
+ <img src="./assets/chart2.png" width="800"/>
129
+ </div>
130
+
131
+ ### Seal Recognition
132
+ <div align="center">
133
+ <img src="./assets/seal.png" width="800"/>
134
+ </div>
135
+
136
+ ### Hierarchical Structure Analysis
137
+ <div align="center">
138
+ <img src="./assets/hierarchical.png" width="800"/>
139
+ </div>
140
+
141
+ <div align="center">
142
+ <img src="./assets/hierarchical2.png" width="800"/>
143
+ </div>
144
+
145
+ ## 🤝 Acknowledgements
146
+ We would like to thank [Youtu-LLM](https://github.com/TencentCloudADP/youtu-tip/tree/master/youtu-llm), [OmniDocBench](https://github.com/opendatalab/OmniDocBench), [olmOCR](https://github.com/allenai/olmocr), [dots.ocr](https://github.com/rednote-hilab/dots.ocr), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [PSENet](https://github.com/whai362/PSENet) for providing model weights, benchmarks and valuable code. We also appreciate everyone's contribution to this open-source project!
147
+
148
+
149
+
150
+ ## 📚 Citation
151
+
152
+ If you find our work useful in your research, please consider citing the following papers:
153
+ ```
154
+ @article{youtu-parsing,
155
+ title={Youtu-Parsing: Perception, Structuring and Recognition via High-Parallelism Decoding},
156
+ author={Tencent Youtu Lab},
157
+ year={2026},
158
+ eprint={},
159
+ archivePrefix={},
160
+ primaryClass={},
161
+ url={},
162
+ }
163
+
164
+ @article{youtu-vl,
165
+ title={Youtu-VL: Unleashing Visual Potential via Unified Vision-Language Supervision},
166
+ author={Tencent Youtu Lab},
167
+ year={2026},
168
+ eprint={},
169
+ archivePrefix={},
170
+ primaryClass={},
171
+ url={},
172
+ }
173
+
174
+ @article{youtu-llm,
175
+ title={Youtu-LLM: Unlocking the Native Agentic Potential for Lightweight Large Language Models},
176
+ author={Tencent Youtu Lab},
177
+ year={2025},
178
+ eprint={2512.24618},
179
+ archivePrefix={arXiv},
180
+ primaryClass={cs.CL},
181
+ url={https://arxiv.org/abs/2512.24618},
182
+ }
183
+ ```
__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 The Youtu Team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+ from transformers.utils import _LazyModule
16
+ from transformers.utils.import_utils import define_import_structure
17
+
18
+ if TYPE_CHECKING:
19
+ from .configuration_youtu_vl import *
20
+ from .modeling_youtu_vl import *
21
+ from .processing_youtu_vl import *
22
+ from .configuration_siglip2 import *
23
+ from .image_processing_siglip2_fast import *
24
+ from .modeling_siglip2 import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
30
+
31
+
assets/OminiDocBench_v1.5.png ADDED

Git LFS Details

  • SHA256: c98f06c02028df4120e77f53cf6b8f3bd30deb6006432f047f5567c88dbe471c
  • Pointer size: 131 Bytes
  • Size of remote file: 506 kB
assets/art.png ADDED

Git LFS Details

  • SHA256: 0215c94c4b0130f7ef39ae4f69f17a1ea1ac9ed53a6eb76ffc89a56e6c76f5c7
  • Pointer size: 131 Bytes
  • Size of remote file: 784 kB
assets/chart.png ADDED

Git LFS Details

  • SHA256: 6f80a5635f58bedd2c19896b796e0efe81c15b0a053831671edc230d8333aa1f
  • Pointer size: 131 Bytes
  • Size of remote file: 588 kB
assets/chart2.png ADDED

Git LFS Details

  • SHA256: 71412df8dfcfd7a8a6f28cbcbcbd516716a5ee9ed44f64a7c209e70d6fc2838d
  • Pointer size: 131 Bytes
  • Size of remote file: 812 kB
assets/formula.png ADDED

Git LFS Details

  • SHA256: 0a61834add1b26c6c59b808e101c6f91dbee873a184a74a6cb6177016c670c8f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/handwriting.png ADDED

Git LFS Details

  • SHA256: b727347801c2885c5b3eee0a86de7a7edfb0bc739754350f464622c930150908
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
assets/hierarchical.png ADDED

Git LFS Details

  • SHA256: c007c2e869fdb9228663a03b38a80eb539bcae42647660aed6794fd6a4c5a142
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
assets/hierarchical2.png ADDED

Git LFS Details

  • SHA256: 6c94e8cef1e74977d5b441cb1259b08e5fad12f04800dff75400cd2126812364
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
assets/olmOCR.png ADDED

Git LFS Details

  • SHA256: 4dd11ce77da4b653b063115c1e7a1821e29a8c78dd57da8dd4898403a153173b
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
assets/parallel_decoder.png ADDED

Git LFS Details

  • SHA256: 2f8ff7eac559b195c6f6e21aed23938c28612ded6892c04e357c2a65a87d4f76
  • Pointer size: 131 Bytes
  • Size of remote file: 303 kB
assets/printed.png ADDED

Git LFS Details

  • SHA256: 41716e85a2e29a6844e10c4b46d1ff56d773fa9aa46fe66480b40e7e99e44f23
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
assets/seal.png ADDED

Git LFS Details

  • SHA256: 5f23fbb7226227a8592c5eb0dc355db372f619ad0bfc5439494648518a6f64a2
  • Pointer size: 131 Bytes
  • Size of remote file: 492 kB
assets/static_v40.png ADDED

Git LFS Details

  • SHA256: 6b115114ec8d011c84da118b202f76586919fdeb8a7c0283c31e6d7c3369b9b4
  • Pointer size: 131 Bytes
  • Size of remote file: 612 kB
assets/wired.png ADDED

Git LFS Details

  • SHA256: 02cf6f58694e62bfd57e2dd00ccc4d359d1536f055e392f98d29707ca80b32dd
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
assets/wireless.png ADDED

Git LFS Details

  • SHA256: 3de3a6a34eeec6a278d6eaa97d97c4fa7546c39a3d1995de62374038817ad5b7
  • Pointer size: 131 Bytes
  • Size of remote file: 848 kB
assets/youtu-parsing-logo.png ADDED

Git LFS Details

  • SHA256: caa339d0dc0c3f144e14a48708502d587ac581c83709fbac776c34279b81f21d
  • Pointer size: 131 Bytes
  • Size of remote file: 446 kB
chat_template.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|begin_of_text|>system\nYou are a helpful assistant.<|end_of_text|>\n{% endif %}<|begin_of_text|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|end_of_text|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|end_of_text|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|begin_of_text|>assistant\n{% endif %}"
3
+ }
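The template above defines the conversational format: each turn is wrapped in `<|begin_of_text|> ... <|end_of_text|>`, a default system prompt is prepended when none is given, and every image placeholder expands to `<|vision_start|><|image_pad|><|vision_end|>`. A minimal sketch of rendering it with the processor's `apply_chat_template` (assuming the processor is loaded from `tencent/Youtu-Parsing` with `trust_remote_code=True`; the message fields mirror the keys the template checks):

```python
# Hedged sketch: render the chat template for one image + one text turn.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("tencent/Youtu-Parsing", trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "page.png"},            # -> <|vision_start|><|image_pad|><|vision_end|>
            {"type": "text", "text": "Parse this document."},
        ],
    }
]

prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# prompt begins with the default system turn and ends with "<|begin_of_text|>assistant\n"
```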
config.json ADDED
@@ -0,0 +1,70 @@
1
+ {
2
+ "architectures": [
3
+ "YoutuVLForConditionalGeneration"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_youtu_vl.YoutuVLConfig",
9
+ "AutoModelForCausalLM": "modeling_youtu_vl.YoutuVLForConditionalGeneration",
10
+ "AutoProcessor": "processing_youtu_vl.YoutuVLProcessor",
11
+ "AutoImageProcessor": "image_processing_siglip2_fast.Siglip2ImageProcessorFast"
12
+ },
13
+ "bos_token_id": 128000,
14
+ "embedding_initializer_range": 0.02795084971874737,
15
+ "eos_token_id": 128001,
16
+ "head_dim": 64,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 2048,
19
+ "image_token_id": 128264,
20
+ "initializer_range": 0.013975424859373685,
21
+ "intermediate_size": 6144,
22
+ "kv_lora_rank": 512,
23
+ "max_position_embeddings": 20480,
24
+ "mlp_bias": false,
25
+ "model_type": "youtu_vl",
26
+ "num_attention_heads": 16,
27
+ "num_hidden_layers": 32,
28
+ "num_key_value_heads": 16,
29
+ "pad_token_id": 128001,
30
+ "q_lora_rank": 1536,
31
+ "qk_head_dim": 192,
32
+ "qk_nope_head_dim": 128,
33
+ "qk_rope_head_dim": 64,
34
+ "rms_norm_eps": 1e-06,
35
+ "rope_interleave": true,
36
+ "rope_theta": 100000,
37
+ "tie_word_embeddings": true,
38
+ "torch_dtype": "bfloat16",
39
+ "transformers_version": "4.56.0",
40
+ "use_cache": false,
41
+ "v_head_dim": 128,
42
+ "video_token_id": 128265,
43
+ "vision_config": {
44
+ "attention_dropout": 0.0,
45
+ "hidden_act": "gelu_pytorch_tanh",
46
+ "hidden_size": 1152,
47
+ "intermediate_size": 4304,
48
+ "layer_norm_eps": 1e-06,
49
+ "model_type": "siglip2_vision_model",
50
+ "num_attention_heads": 16,
51
+ "num_channels": 3,
52
+ "num_hidden_layers": 27,
53
+ "num_patches": 4096,
54
+ "out_hidden_size": 2048,
55
+ "patch_size": 16,
56
+ "tokens_per_second": 2,
57
+ "torch_dtype": "bfloat16",
58
+ "vision_use_head": false,
59
+ "window_size": 256,
60
+ "fullatt_block_indexes": [
61
+ 7,
62
+ 15,
63
+ 23,
64
+ 26
65
+ ]
66
+ },
67
+ "vision_end_token_id": 128263,
68
+ "vision_start_token_id": 128262,
69
+ "vocab_size": 182646
70
+ }
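The attention settings above use low-rank query and key-value projections (`q_lora_rank`, `kv_lora_rank`) with split rotary/non-rotary head dimensions, so `qk_head_dim` is the sum of `qk_nope_head_dim` and `qk_rope_head_dim` (128 + 64 = 192). A rough sketch of loading the config through its `auto_map` entry and checking a few of these fields (assumes `trust_remote_code=True` and the repo id `tencent/Youtu-Parsing`):

```python
# Hedged sketch: load the custom YoutuVLConfig and inspect derived dimensions.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("tencent/Youtu-Parsing", trust_remote_code=True)

assert cfg.qk_head_dim == cfg.qk_nope_head_dim + cfg.qk_rope_head_dim   # 128 + 64 == 192
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)  # 2048 32 16
print(cfg.vision_config.hidden_size, cfg.vision_config.patch_size)      # 1152 16
```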
configuration_siglip2.py ADDED
@@ -0,0 +1,178 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
+ class Siglip2TextConfig(PretrainedConfig):
9
+ r"""
10
+ Args:
11
+ vocab_size (`int`, *optional*, defaults to 32000):
12
+ Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by
13
+ the `inputs_ids` passed when calling [`Siglip2Model`].
14
+ hidden_size (`int`, *optional*, defaults to 768):
15
+ Dimensionality of the encoder layers and the pooler layer.
16
+ intermediate_size (`int`, *optional*, defaults to 3072):
17
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
18
+ num_hidden_layers (`int`, *optional*, defaults to 12):
19
+ Number of hidden layers in the Transformer encoder.
20
+ num_attention_heads (`int`, *optional*, defaults to 12):
21
+ Number of attention heads for each attention layer in the Transformer encoder.
22
+ max_position_embeddings (`int`, *optional*, defaults to 64):
23
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
24
+ just in case (e.g., 512 or 1024 or 2048).
25
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
26
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
27
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
28
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
29
+ The epsilon used by the layer normalization layers.
30
+ attention_dropout (`float`, *optional*, defaults to 0.0):
31
+ The dropout ratio for the attention probabilities.
32
+ pad_token_id (`int`, *optional*, defaults to 1):
33
+ The id of the padding token in the vocabulary.
34
+ bos_token_id (`int`, *optional*, defaults to 49406):
35
+ The id of the beginning-of-sequence token in the vocabulary.
36
+ eos_token_id (`int`, *optional*, defaults to 49407):
37
+ The id of the end-of-sequence token in the vocabulary.
38
+ projection_size (`int`, *optional*, defaults to `hidden_size`):
39
+ The size of the projection head.
40
+
41
+ """
42
+
43
+ model_type = "siglip2_text_model"
44
+ base_config_key = "text_config"
45
+
46
+ def __init__(
47
+ self,
48
+ vocab_size=32000,
49
+ hidden_size=768,
50
+ intermediate_size=3072,
51
+ num_hidden_layers=12,
52
+ num_attention_heads=12,
53
+ max_position_embeddings=64,
54
+ hidden_act="gelu_pytorch_tanh",
55
+ layer_norm_eps=1e-6,
56
+ attention_dropout=0.0,
57
+ pad_token_id=1,
58
+ bos_token_id=49406,
59
+ eos_token_id=49407,
60
+ projection_size=None,
61
+ **kwargs,
62
+ ):
63
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
64
+
65
+ self.vocab_size = vocab_size
66
+ self.hidden_size = hidden_size
67
+ self.intermediate_size = intermediate_size
68
+ self.num_hidden_layers = num_hidden_layers
69
+ self.num_attention_heads = num_attention_heads
70
+ self.max_position_embeddings = max_position_embeddings
71
+ self.layer_norm_eps = layer_norm_eps
72
+ self.hidden_act = hidden_act
73
+ self.attention_dropout = attention_dropout
74
+ self.projection_size = projection_size if projection_size is not None else hidden_size
75
+
76
+
77
+ class Siglip2VisionConfig(PretrainedConfig):
78
+ r"""
79
+ Args:
80
+ hidden_size (`int`, *optional*, defaults to 768):
81
+ Dimensionality of the encoder layers and the pooler layer.
82
+ intermediate_size (`int`, *optional*, defaults to 3072):
83
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
84
+ num_hidden_layers (`int`, *optional*, defaults to 12):
85
+ Number of hidden layers in the Transformer encoder.
86
+ num_attention_heads (`int`, *optional*, defaults to 12):
87
+ Number of attention heads for each attention layer in the Transformer encoder.
88
+ num_channels (`int`, *optional*, defaults to 3):
89
+ Number of channels in the input images.
90
+ num_patches (`int`, *optional*, defaults to 256):
91
+ The number of patches in the image with the size of (`patch_size`, `patch_size`).
92
+ The image is resized to fill maximum of this number of patches, and to preserve
93
+ the aspect ratio. In case the resulting number of patches is lower, the image is
94
+ padded in "patch" dimension.
95
+ patch_size (`int`, *optional*, defaults to 16):
96
+ The size (resolution) of each patch.
97
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
98
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
99
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
100
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
101
+ The epsilon used by the layer normalization layers.
102
+ attention_dropout (`float`, *optional*, defaults to 0.0):
103
+ The dropout ratio for the attention probabilities.
104
+
105
+ """
106
+
107
+ model_type = "siglip2_vision_model"
108
+ base_config_key = "vision_config"
109
+
110
+ def __init__(
111
+ self,
112
+ hidden_size=768,
113
+ out_hidden_size=2048,
114
+ intermediate_size=3072,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ num_channels=3,
118
+ num_patches=256,
119
+ patch_size=16,
120
+ hidden_act="gelu_pytorch_tanh",
121
+ layer_norm_eps=1e-6,
122
+ attention_dropout=0.0,
123
+ **kwargs,
124
+ ):
125
+ super().__init__(**kwargs)
126
+
127
+ self.hidden_size = hidden_size
128
+ self.out_hidden_size = out_hidden_size
129
+ self.intermediate_size = intermediate_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+ self.num_channels = num_channels
133
+ self.patch_size = patch_size
134
+ self.attention_dropout = attention_dropout
135
+ self.layer_norm_eps = layer_norm_eps
136
+ self.hidden_act = hidden_act
137
+ self.num_patches = num_patches
138
+ self.in_features = -1
139
+
140
+
141
+ class Siglip2Config(PretrainedConfig):
142
+ r"""
143
+ Args:
144
+ text_config (`dict`, *optional*):
145
+ Dictionary of configuration options used to initialize [`Siglip2TextConfig`].
146
+ vision_config (`dict`, *optional*):
147
+ Dictionary of configuration options used to initialize [`Siglip2VisionConfig`].
148
+ kwargs (*optional*):
149
+ Dictionary of keyword arguments.
150
+
151
+ """
152
+
153
+ model_type = "siglip2"
154
+ sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig}
155
+
156
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
157
+ super().__init__(**kwargs)
158
+
159
+ if text_config is None:
160
+ text_config = {}
161
+ logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.")
162
+
163
+ if vision_config is None:
164
+ vision_config = {}
165
+ logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.")
166
+
167
+ self.text_config = Siglip2TextConfig(**text_config)
168
+ self.vision_config = Siglip2VisionConfig(**vision_config)
169
+
170
+ self.initializer_factor = 1.0
171
+
172
+ @classmethod
173
+ def from_text_vision_configs(cls, text_config: Siglip2TextConfig, vision_config: Siglip2VisionConfig, **kwargs):
174
+
175
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
176
+
177
+
178
+ __all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"]
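For reference, the vision tower of this checkpoint instantiates `Siglip2VisionConfig` with the values stored under `vision_config` in `config.json`; extra keys such as `window_size` and `fullatt_block_indexes` are carried through `**kwargs`. A minimal sketch (values copied from that file; the import assumes this repository's `configuration_siglip2.py` is on the Python path):

```python
# Hedged sketch: build the vision config with this checkpoint's values.
from configuration_siglip2 import Siglip2VisionConfig

vision_cfg = Siglip2VisionConfig(
    hidden_size=1152,
    out_hidden_size=2048,
    intermediate_size=4304,
    num_hidden_layers=27,
    num_attention_heads=16,
    num_patches=4096,
    patch_size=16,
    window_size=256,                        # stored as an extra attribute via **kwargs
    fullatt_block_indexes=[7, 15, 23, 26],  # stored as an extra attribute via **kwargs
)
print(vision_cfg.num_patches, vision_cfg.out_hidden_size)  # 4096 2048
```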
configuration_youtu_vl.py ADDED
@@ -0,0 +1,209 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.modeling_rope_utils import rope_config_validation
3
+ from .configuration_siglip2 import Siglip2VisionConfig
4
+
5
+
6
+ YOUTU_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
7
+
8
+ class YoutuVLConfig(PretrainedConfig):
9
+ r"""
10
+ Args:
11
+ vocab_size (`int`, *optional*, defaults to 129280):
12
+ Vocabulary size of the YoutuVL model. Defines the number of different tokens that can be represented by the
13
+ `inputs_ids` passed when calling [`YoutuModel`]
14
+ hidden_size (`int`, *optional*, defaults to 7168):
15
+ Dimension of the hidden representations.
16
+ intermediate_size (`int`, *optional*, defaults to 18432):
17
+ Dimension of the MLP representations.
18
+ num_hidden_layers (`int`, *optional*, defaults to 61):
19
+ Number of hidden layers in the Transformer decoder.
20
+ num_attention_heads (`int`, *optional*, defaults to 128):
21
+ Number of attention heads for each attention layer in the Transformer decoder.
22
+ num_key_value_heads (`int`, *optional*, defaults to 128):
23
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
24
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
25
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
26
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
27
+ by meanpooling all the original heads within that group. For more details checkout [this
28
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
29
+ `num_attention_heads`.
30
+ n_shared_experts (`int`, *optional*, defaults to 1):
31
+ Number of shared experts.
32
+ n_routed_experts (`int`, *optional*, defaults to 256):
33
+ Number of routed experts.
34
+ routed_scaling_factor (`float`, *optional*, defaults to 2.5):
35
+ Scaling factor for routed experts.
36
+ kv_lora_rank (`int`, *optional*, defaults to 512):
37
+ Rank of the LoRA matrices for key and value projections.
38
+ q_lora_rank (`int`, *optional*, defaults to 1536):
39
+ Rank of the LoRA matrices for query projections.
40
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
41
+ Dimension of the query/key heads that use rotary position embeddings.
42
+ v_head_dim (`int`, *optional*, defaults to 128):
43
+ Dimension of the value heads.
44
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
45
+ Dimension of the query/key heads that don't use rotary position embeddings.
46
+ n_group (`int`, *optional*, defaults to 8):
47
+ Number of groups for routed experts.
48
+ topk_group (`int`, *optional*, defaults to 4):
49
+ Number of selected groups for each token.
50
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
51
+ Number of selected experts, None means dense model.
52
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
53
+ Whether to normalize the weights of the routed experts.
54
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
55
+ The non-linear activation function (function or string) in the decoder.
56
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
57
+ The maximum sequence length that this model might ever be used with.
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
61
+ The epsilon used by the rms normalization layers.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*, defaults to 0):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 1):
70
+ End of stream token id.
71
+ pretraining_tp (`int`, *optional*, defaults to 1):
72
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
73
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
74
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
75
+ issue](https://github.com/pytorch/pytorch/issues/76232).
76
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
77
+ Whether to tie weight embeddings
78
+ rope_theta (`float`, *optional*, defaults to 10000.0):
79
+ The base period of the RoPE embeddings.
80
+ rope_scaling (`Dict`, *optional*):
81
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
82
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
83
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
84
+ `max_position_embeddings` to the expected new maximum.
85
+ rope_interleave (`bool`, *optional*, defaults to `True`):
86
+ Whether to interleave the rotary position embeddings.
87
+ attention_bias (`bool`, *optional*, defaults to `False`):
88
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
89
+ attention_dropout (`float`, *optional*, defaults to 0.0):
90
+ The dropout ratio for the attention probabilities.
91
+
92
+ """
93
+
94
+ sub_configs = {"vision_config": Siglip2VisionConfig}
95
+ model_type = "youtu_vl"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+ base_model_pp_plan = {
98
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
99
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
100
+ "norm": (["hidden_states"], ["hidden_states"]),
101
+ }
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=129280,
106
+ hidden_size=7168,
107
+ intermediate_size=18432,
108
+ num_hidden_layers=61,
109
+ num_attention_heads=128,
110
+ num_key_value_heads=128,
111
+ n_shared_experts=1,
112
+ n_routed_experts=256,
113
+ routed_scaling_factor=2.5,
114
+ kv_lora_rank=512,
115
+ q_lora_rank=1536,
116
+ qk_rope_head_dim=64,
117
+ v_head_dim=128,
118
+ qk_nope_head_dim=128,
119
+ n_group=8,
120
+ topk_group=4,
121
+ num_experts_per_tok=8,
122
+ norm_topk_prob=True,
123
+ hidden_act="silu",
124
+ max_position_embeddings=4096,
125
+ initializer_range=None,
126
+ embedding_initializer_range=None,
127
+ rms_norm_eps=1e-6,
128
+ use_cache=True,
129
+ pad_token_id=None,
130
+ bos_token_id=0,
131
+ eos_token_id=1,
132
+ pretraining_tp=1,
133
+ tie_word_embeddings=False,
134
+ rope_theta=10000.0,
135
+ rope_scaling=None,
136
+ rope_interleave=True,
137
+ attention_bias=False,
138
+ attention_dropout=0.0,
139
+ vision_config=None,
140
+ **kwargs,
141
+ ):
142
+ if isinstance(vision_config, dict):
143
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
144
+ elif vision_config is None:
145
+ self.vision_config = self.sub_configs["vision_config"]()
146
+
147
+ self.vocab_size = vocab_size
148
+ self.max_position_embeddings = max_position_embeddings
149
+ self.hidden_size = hidden_size
150
+ self.intermediate_size = intermediate_size
151
+ self.num_hidden_layers = num_hidden_layers
152
+ self.num_attention_heads = num_attention_heads
153
+ self.n_shared_experts = n_shared_experts
154
+ self.n_routed_experts = n_routed_experts
155
+ self.routed_scaling_factor = routed_scaling_factor
156
+ self.kv_lora_rank = kv_lora_rank
157
+ self.q_lora_rank = q_lora_rank
158
+ self.qk_rope_head_dim = qk_rope_head_dim
159
+ self.v_head_dim = v_head_dim
160
+ self.qk_nope_head_dim = qk_nope_head_dim
161
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
162
+ self.head_dim = qk_rope_head_dim
163
+ self.n_group = n_group
164
+ self.topk_group = topk_group
165
+ self.num_experts_per_tok = num_experts_per_tok
166
+ self.norm_topk_prob = norm_topk_prob
167
+ self.rope_interleave = rope_interleave
168
+ self.flash_att_sliding_window = None
169
+
170
+ self.mlp_bias = False
171
+ self.mtp_loss_weight = 0.3
172
+
173
+ if num_key_value_heads is None:
174
+ num_key_value_heads = num_attention_heads
175
+
176
+ self.num_key_value_heads = num_key_value_heads
177
+ self.hidden_act = hidden_act
178
+ self.initializer_range = (
179
+ (2.0 / (5.0 * self.hidden_size)) ** 0.5
180
+ if initializer_range is None
181
+ else initializer_range
182
+ )
183
+ self.embedding_initializer_range = (
184
+ self.initializer_range * 2.0
185
+ if embedding_initializer_range is None
186
+ else embedding_initializer_range
187
+ )
188
+ self.rms_norm_eps = rms_norm_eps
189
+ self.pretraining_tp = pretraining_tp
190
+ self.use_cache = use_cache
191
+ self.rope_theta = rope_theta
192
+ self.rope_scaling = rope_scaling
193
+ self.attention_bias = attention_bias
194
+ self.attention_dropout = attention_dropout
195
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
196
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
197
+ rope_config_validation(self)
198
+
199
+ super().__init__(
200
+ pad_token_id=pad_token_id,
201
+ bos_token_id=bos_token_id,
202
+ eos_token_id=eos_token_id,
203
+ tie_word_embeddings=tie_word_embeddings,
204
+ **kwargs,
205
+ )
206
+
207
+
208
+ __all__ = ["YoutuVLConfig"]
209
+
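The configuration above is mostly self-describing, but a small usage sketch may help when reading the defaults. This is illustrative only: it assumes the file is importable directly as `configuration_youtu_vl` (in practice `AutoConfig.from_pretrained(..., trust_remote_code=True)` would be the usual entry point), and the reduced sizes are placeholders rather than the values shipped in this repository's `config.json`.

```python
# Illustrative sketch; the module path and the toy sizes below are assumptions,
# not the values stored in this repository's config.json.
from configuration_youtu_vl import YoutuVLConfig

config = YoutuVLConfig(
    hidden_size=1024,
    num_hidden_layers=4,
    # rope_scaling follows the documented {"type": ..., "factor": ...} format.
    rope_scaling={"type": "linear", "factor": 2.0},
)

print(config.qk_head_dim)        # 192 = qk_nope_head_dim (128) + qk_rope_head_dim (64)
print(config.initializer_range)  # (2 / (5 * hidden_size)) ** 0.5 when not passed explicitly
print(config.rope_scaling)       # "type" is mirrored into "rope_type" before validation
```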
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 128000,
4
+ "eos_token_id": 128001,
5
+ "pad_token_id": 128001,
6
+ "transformers_version": "4.56.0"
7
+ }
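These generation defaults are what `generate()` picks up for this checkpoint. A minimal sketch of inspecting them, assuming the repository has been downloaded locally (the path below is a placeholder):

```python
# Sketch only; replace the placeholder path with the actual checkpoint location or Hub id.
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("path/to/this-checkpoint")
print(gen_config.bos_token_id)  # 128000
print(gen_config.eos_token_id)  # 128001
print(gen_config.pad_token_id)  # 128001 -- padding reuses the end-of-stream id
```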
image_processing_siglip2_fast.py ADDED
@@ -0,0 +1,328 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import os
3
+ import torch
4
+ import math
5
+ from torchvision.transforms import functional as F
6
+ from transformers.image_processing_utils import BatchFeature
7
+ from transformers.image_processing_utils_fast import (
8
+ BaseImageProcessorFast,
9
+ DefaultFastImageProcessorKwargs,
10
+ SizeDict,
11
+ )
12
+ from transformers.image_utils import (
13
+ ImageInput,
14
+ PILImageResampling,
15
+ )
16
+ from transformers.processing_utils import Unpack
17
+ from transformers.utils import (
18
+ TensorType,
19
+ add_start_docstrings,
20
+ is_torch_available,
21
+ is_torchvision_available,
22
+ is_torchvision_v2_available,
23
+ logging,
24
+ )
25
+
26
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
27
+
28
+ Args:
29
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
30
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
31
+ `do_resize` parameter in the `preprocess` method.
32
+ size (`dict`, *optional*, defaults to `self.size`):
33
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
34
+ method.
35
+ default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
36
+ Whether to default to a square image when resizing, if size is an int.
37
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
38
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
39
+ overridden by the `resample` parameter in the `preprocess` method.
40
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
41
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
42
+ `preprocess` method.
43
+ crop_size (`Dict[str, int]` *optional*, defaults to `self.crop_size`):
44
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
45
+ method.
46
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
47
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
48
+ `do_rescale` parameter in the `preprocess` method.
49
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
50
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
51
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
52
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
53
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
54
+ method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
55
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
56
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
57
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
59
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
60
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
61
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
63
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
64
+ Whether to convert the image to RGB.
65
+ return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
66
+ Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
67
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
68
+ Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
69
+ input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
70
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
71
+ from the input image. Can be one of:
72
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
73
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
74
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
75
+ device (`torch.device`, *optional*, defaults to `self.device`):
76
+ The device to process the images on. If unset, the device is inferred from the input images."""
77
+
78
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
79
+ Preprocess an image or batch of images.
80
+
81
+ Args:
82
+ images (`ImageInput`):
83
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
84
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
85
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
86
+ Whether to resize the image.
87
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
88
+ Describes the maximum input dimensions to the model.
89
+ resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
90
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
91
+ has an effect if `do_resize` is set to `True`.
92
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
93
+ Whether to center crop the image.
94
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
95
+ Size of the output image after applying `center_crop`.
96
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
97
+ Whether to rescale the image.
98
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
99
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
100
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
101
+ Whether to normalize the image.
102
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
103
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
104
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
105
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
106
+ `True`.
107
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
108
+ Whether to convert the image to RGB.
109
+ return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
110
+ Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
111
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
112
+ Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
113
+ input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
114
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
115
+ from the input image. Can be one of:
116
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
117
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
118
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
119
+ device (`torch.device`, *optional*, defaults to `self.device`):
120
+ The device to process the images on. If unset, the device is inferred from the input images."""
121
+
122
+
123
+ if is_torch_available():
124
+ import torch
125
+
126
+ if is_torchvision_available():
127
+ if is_torchvision_v2_available():
128
+ from torchvision.transforms.v2 import functional as F
129
+ else:
130
+ from torchvision.transforms import functional as F
131
+
132
+
133
+ logger = logging.get_logger(__name__)
134
+
135
+
136
+ def get_image_size_for_patches(
137
+ image_height: int, image_width: int, patch_size: int, max_num_patches: int
138
+ ) -> Tuple[int, int]:
139
+ """
140
+ Args:
141
+ image_height (`int`):
142
+ Original image height.
143
+ image_width (`int`):
144
+ Original image width.
145
+ patch_size (`int`):
146
+ Patch size for processing.
+ max_num_patches (`int`):
+ Maximum number of patches allowed; the image is scaled down until its patch grid fits this budget.
147
+
148
+ Returns:
149
+ Tuple: (target_height, target_width)
150
+ """
151
+
152
+ def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
153
+ patch_size = patch_size * 2
154
+ scaled_size = size * scale
155
+ scaled_size = math.ceil(scaled_size / patch_size) * patch_size
156
+ scaled_size = max(patch_size, scaled_size)
157
+ return int(scaled_size)
158
+
159
+ scale = 1.0
160
+ while True:
161
+ target_height = get_scaled_image_size(scale, image_height, patch_size)
162
+ target_width = get_scaled_image_size(scale, image_width, patch_size)
163
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
164
+
165
+ if num_patches > max_num_patches:
166
+ scale -= 0.02
167
+ else:
168
+ break
169
+
170
+ return target_height, target_width
171
+
172
+
173
+ def convert_image_to_patches(image: "torch.Tensor", patch_size: int, merge_size: int) -> "torch.Tensor":
174
+ """
175
+ Converts an input image into flattened patches.
176
+
177
+ Args:
178
+ image: Input image tensor of shape (channels, height, width)
179
+ patch_size: Size of each square patch (in pixels)
180
+ merge_size: Number of adjacent patches to merge
181
+
+ Returns:
+ Flattened patches of shape `(num_patches, channels * patch_size * patch_size)`, with each
+ `merge_size x merge_size` block of neighboring patches kept contiguous.
+ """
183
+
184
+ num_channels, image_height, image_width = image.shape
185
+ num_patches_height = image_height // patch_size
186
+ num_patches_width = image_width // patch_size
187
+ patched_image = image.reshape(num_channels,
188
+ num_patches_height//merge_size,
189
+ merge_size, patch_size,
190
+ num_patches_width//merge_size,
191
+ merge_size, patch_size)
192
+ patched_image = patched_image.permute(1, 4, 2, 5, 3, 6, 0)
193
+ patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
194
+ return patched_image
195
+
196
+ def pad_along_first_dim(
197
+ tensor: "torch.Tensor", target_length: int, pad_value: int = 0
198
+ ) -> Tuple["torch.Tensor", "torch.Tensor"]:
199
+ """
200
+ Pad the input tensor along its first dimension to a target length.
201
+
202
+ Args:
203
+ tensor (torch.Tensor): The input tensor to be padded.
204
+ target_length (int): The desired length of the first dimension after padding.
205
+ pad_value (int, optional): The value to use for padding. Defaults to 0.
206
+ """
207
+ current_length = tensor.shape[0]
208
+ padding_length = target_length - current_length
209
+ mask = torch.ones((target_length,), dtype=torch.int32)
210
+ if padding_length > 0:
211
+ padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length]
212
+ tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value)
213
+ mask[-padding_length:] = 0
214
+ return tensor, mask
215
+
216
+
217
+ class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
218
+ patch_size: Optional[int]
219
+ max_num_patches: Optional[int]
220
+
221
+
222
+ @add_start_docstrings(
223
+ r"Constructs a fast Siglip2 image processor.",
224
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
225
+ """
226
+ patch_size (`int`, *optional*, defaults to 16):
227
+ The size (resolution) of each patch the image will be split into.
228
+ max_num_patches (`int`, *optional*, defaults to 256):
229
+ The image will be resized to have at most this number of patches,
230
+ and then padded in "patch" dimension to match this number exactly.
231
+ """,
232
+ )
233
+ class Siglip2ImageProcessorFast(BaseImageProcessorFast):
234
+ resample = PILImageResampling.BILINEAR
235
+ image_mean = [0.5, 0.5, 0.5]
236
+ image_std = [0.5, 0.5, 0.5]
237
+ do_resize = True
238
+ do_rescale = True
239
+ do_normalize = True
240
+ patch_size = 16
241
+ max_num_patches = 256
242
+ valid_kwargs = Siglip2FastImageProcessorKwargs
243
+ unused_kwargs = ["size", "do_center_crop", "crop_size"]
244
+ print_max_patched = True
245
+
246
+ def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
247
+ super().__init__(**kwargs)
248
+
249
+ def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
250
+ kwargs.pop("do_resize", None)
251
+ return super()._validate_preprocess_kwargs(**kwargs)
252
+
253
+ @add_start_docstrings(
254
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
255
+ """
256
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
257
+ The size (resolution) of each patch the image will be split into.
258
+ max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
259
+ The image will be resized to have at most this number of patches,
260
+ and then padded in "patch" dimension to match this number exactly.
261
+ """,
262
+ )
263
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
264
+ return super().preprocess(images, **kwargs)
265
+
266
+ def get_max_image_patches(self, images):
267
+ return 4096 * 6 * 6
268
+
269
+ def _preprocess(
270
+ self,
271
+ images: List["torch.Tensor"],
272
+ do_resize: bool,
273
+ patch_size: int,
274
+ max_num_patches: int,
275
+ interpolation: Optional["F.InterpolationMode"],
276
+ do_rescale: bool,
277
+ rescale_factor: float,
278
+ do_normalize: bool,
279
+ image_mean: Optional[Union[float, List[float]]],
280
+ image_std: Optional[Union[float, List[float]]],
281
+ return_tensors: Optional[Union[str, TensorType]],
282
+ **kwargs,
283
+ ) -> BatchFeature:
284
+ pixel_masks = []
285
+ pixel_values = []
286
+ spatial_shapes = []
287
+
288
+ if Siglip2ImageProcessorFast.print_max_patched:
289
+ Siglip2ImageProcessorFast.print_max_patched = False
290
+
291
+ for i, image in enumerate(images):
292
+ height, width = get_image_size_for_patches(
293
+ image_height=image.shape[1],
294
+ image_width=image.shape[2],
295
+ patch_size=patch_size,
296
+ max_num_patches=max_num_patches,
297
+ )
298
+
299
+ side_dict = SizeDict(height=height, width=width)
300
+ image = self.resize(image=image, size=side_dict, interpolation=interpolation)
301
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
302
+
303
+ patches = convert_image_to_patches(image, patch_size, 2)
304
+ patches, mask = pad_along_first_dim(patches, len(patches))  # target equals current length, so no padding is added and the mask is all ones
305
+
306
+ num_patches_height = image.shape[1] // patch_size
307
+ num_patches_width = image.shape[2] // patch_size
308
+
309
+ spatial_shapes.append((num_patches_height, num_patches_width))
310
+ pixel_values.append(patches)
311
+ pixel_masks.append(mask)
312
+
313
+ pixel_values = torch.stack(pixel_values, dim=0)
314
+ pixel_masks = torch.stack(pixel_masks, dim=0)
315
+ spatial_shapes = torch.tensor(spatial_shapes)
316
+
317
+ batch_feature = BatchFeature(
318
+ data={
319
+ "pixel_values": pixel_values,
320
+ "pixel_attention_mask": pixel_masks,
321
+ "spatial_shapes": spatial_shapes,
322
+ },
323
+ tensor_type=return_tensors,
324
+ )
325
+ return batch_feature
326
+
327
+
328
+ __all__ = ["Siglip2ImageProcessorFast"]
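A short sketch of the patching pipeline implemented above, useful for sanity-checking shapes. It assumes the file is importable as `image_processing_siglip2_fast` and that torch/torchvision are installed; the image size and patch budget are arbitrary example values.

```python
# Illustrative sketch; the module path and the example numbers are assumptions.
import torch
from image_processing_siglip2_fast import (
    convert_image_to_patches,
    get_image_size_for_patches,
)

# Both sides are rounded up to multiples of 2 * patch_size = 32, then scaled down
# until the 16x16 patch grid fits within the max_num_patches budget.
h, w = get_image_size_for_patches(
    image_height=900, image_width=1200, patch_size=16, max_num_patches=1024
)
assert h % 32 == 0 and w % 32 == 0
assert (h // 16) * (w // 16) <= 1024

# Each 16x16 patch flattens to 16 * 16 * 3 = 768 values; merge_size=2 keeps every
# 2x2 block of neighboring patches contiguous in the output ordering.
image = torch.rand(3, h, w)
patches = convert_image_to_patches(image, patch_size=16, merge_size=2)
assert patches.shape == ((h // 16) * (w // 16), 768)
```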
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f63df3c26a31537118c524d375566cd022b3f5c2ec7e14aa096f7cd5d2bfc2
3
+ size 4981728088
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7620ba07a90acb406e36436ac513f691c675a2ad2d459bbb5fae053b0a507c3
3
+ size 50344496
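The two shard files above are Git LFS pointers to the actual safetensors payloads; the index added next maps every parameter name to its shard. A minimal sketch of resolving a single tensor through that index, assuming the shards and the index file sit in the current directory:

```python
# Sketch only; assumes the shard files and model.safetensors.index.json are local.
import json

from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "lm_head.weight"
shard = index["weight_map"][name]        # e.g. "model-00002-of-00002.safetensors"
with safe_open(shard, framework="pt") as handle:
    tensor = handle.get_tensor(name)
print(name, tuple(tensor.shape))
```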
model.safetensors.index.json ADDED
@@ -0,0 +1,835 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 5779947232
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "merger.ln_q.weight": "model-00001-of-00002.safetensors",
8
+ "merger.mlp.0.bias": "model-00001-of-00002.safetensors",
9
+ "merger.mlp.0.weight": "model-00001-of-00002.safetensors",
10
+ "merger.mlp.2.bias": "model-00001-of-00002.safetensors",
11
+ "merger.mlp.2.weight": "model-00001-of-00002.safetensors",
12
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.0.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.0.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.0.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.0.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.0.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.1.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.1.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.1.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.1.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.1.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.1.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.10.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.10.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.10.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.10.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.10.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.10.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.11.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.11.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.11.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.11.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.11.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.11.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.12.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.12.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.12.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.12.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.12.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.12.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.13.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.13.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.13.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.13.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.13.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.13.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.14.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.14.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.14.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.14.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.14.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.14.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.15.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.15.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.15.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.15.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.15.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.15.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
109
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.16.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.16.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.16.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.16.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.16.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.16.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.17.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.17.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.17.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.17.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.17.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
132
+ "model.layers.17.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
133
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
134
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.18.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.18.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.18.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.18.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.18.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.18.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
148
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
150
+ "model.layers.19.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.19.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.19.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.19.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.19.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.19.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
162
+ "model.layers.2.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.2.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
164
+ "model.layers.2.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
166
+ "model.layers.2.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.2.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.2.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
169
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.20.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.20.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
176
+ "model.layers.20.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
177
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.20.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.20.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
180
+ "model.layers.20.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
184
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.21.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.21.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.21.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
189
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
190
+ "model.layers.21.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
191
+ "model.layers.21.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
192
+ "model.layers.21.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
193
+ "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors",
194
+ "model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
196
+ "model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
197
+ "model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
198
+ "model.layers.22.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
199
+ "model.layers.22.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
200
+ "model.layers.22.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
201
+ "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
202
+ "model.layers.22.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
203
+ "model.layers.22.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
204
+ "model.layers.22.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
205
+ "model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors",
206
+ "model.layers.23.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
207
+ "model.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
208
+ "model.layers.23.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
209
+ "model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
210
+ "model.layers.23.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
211
+ "model.layers.23.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
212
+ "model.layers.23.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
213
+ "model.layers.23.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
214
+ "model.layers.23.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
215
+ "model.layers.23.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
216
+ "model.layers.23.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
217
+ "model.layers.24.input_layernorm.weight": "model-00001-of-00002.safetensors",
218
+ "model.layers.24.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
219
+ "model.layers.24.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.24.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
221
+ "model.layers.24.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.24.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.24.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
224
+ "model.layers.24.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
225
+ "model.layers.24.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
226
+ "model.layers.24.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
227
+ "model.layers.24.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
228
+ "model.layers.24.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
229
+ "model.layers.25.input_layernorm.weight": "model-00001-of-00002.safetensors",
230
+ "model.layers.25.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
231
+ "model.layers.25.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
232
+ "model.layers.25.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
233
+ "model.layers.25.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.25.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
235
+ "model.layers.25.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.25.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.25.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.25.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
239
+ "model.layers.25.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
240
+ "model.layers.25.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
241
+ "model.layers.26.input_layernorm.weight": "model-00001-of-00002.safetensors",
242
+ "model.layers.26.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.26.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
244
+ "model.layers.26.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.26.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
246
+ "model.layers.26.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.26.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.26.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
249
+ "model.layers.26.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
250
+ "model.layers.26.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
251
+ "model.layers.26.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.26.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
253
+ "model.layers.27.input_layernorm.weight": "model-00001-of-00002.safetensors",
254
+ "model.layers.27.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
255
+ "model.layers.27.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.27.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.27.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
258
+ "model.layers.27.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.27.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.27.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.27.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.27.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.27.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.27.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
265
+ "model.layers.28.input_layernorm.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.28.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.28.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
268
+ "model.layers.28.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.28.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.28.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.28.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.28.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.28.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.28.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.28.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.28.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
277
+ "model.layers.29.input_layernorm.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.29.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
279
+ "model.layers.29.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
280
+ "model.layers.29.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
281
+ "model.layers.29.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
282
+ "model.layers.29.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
283
+ "model.layers.29.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
284
+ "model.layers.29.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
285
+ "model.layers.29.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
286
+ "model.layers.29.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
287
+ "model.layers.29.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
288
+ "model.layers.29.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
289
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
290
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
291
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
292
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
293
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
294
+ "model.layers.3.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
295
+ "model.layers.3.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
296
+ "model.layers.3.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
297
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
298
+ "model.layers.3.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
299
+ "model.layers.3.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
300
+ "model.layers.3.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
301
+ "model.layers.30.input_layernorm.weight": "model-00001-of-00002.safetensors",
302
+ "model.layers.30.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
303
+ "model.layers.30.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
304
+ "model.layers.30.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
305
+ "model.layers.30.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
306
+ "model.layers.30.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
307
+ "model.layers.30.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
308
+ "model.layers.30.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
309
+ "model.layers.30.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
310
+ "model.layers.30.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
311
+ "model.layers.30.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
312
+ "model.layers.30.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
313
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
314
+ "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
315
+ "model.layers.31.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
316
+ "model.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
317
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
318
+ "model.layers.31.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
319
+ "model.layers.31.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
320
+ "model.layers.31.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
321
+ "model.layers.31.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
322
+ "model.layers.31.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
323
+ "model.layers.31.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
324
+ "model.layers.31.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
325
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
326
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
327
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
328
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
329
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
330
+ "model.layers.4.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
331
+ "model.layers.4.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
332
+ "model.layers.4.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
333
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
334
+ "model.layers.4.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
335
+ "model.layers.4.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
336
+ "model.layers.4.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
337
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
338
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
339
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
340
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
341
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
342
+ "model.layers.5.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
343
+ "model.layers.5.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
344
+ "model.layers.5.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
345
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
346
+ "model.layers.5.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
347
+ "model.layers.5.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
348
+ "model.layers.5.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
349
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
350
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
351
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
352
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
353
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
354
+ "model.layers.6.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
355
+ "model.layers.6.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
356
+ "model.layers.6.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
357
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
358
+ "model.layers.6.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
359
+ "model.layers.6.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
360
+ "model.layers.6.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
361
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
362
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
363
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
364
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
365
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
366
+ "model.layers.7.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
367
+ "model.layers.7.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
368
+ "model.layers.7.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
369
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
370
+ "model.layers.7.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
371
+ "model.layers.7.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
372
+ "model.layers.7.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
373
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
374
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
375
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
376
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
377
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
378
+ "model.layers.8.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
379
+ "model.layers.8.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
380
+ "model.layers.8.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
381
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
382
+ "model.layers.8.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
383
+ "model.layers.8.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
384
+ "model.layers.8.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
385
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
386
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
387
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
388
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
389
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
390
+ "model.layers.9.self_attn.kv_a_layernorm.weight": "model-00001-of-00002.safetensors",
391
+ "model.layers.9.self_attn.kv_a_proj_with_mqa.weight": "model-00001-of-00002.safetensors",
392
+ "model.layers.9.self_attn.kv_b_proj.weight": "model-00001-of-00002.safetensors",
393
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
394
+ "model.layers.9.self_attn.q_a_layernorm.weight": "model-00001-of-00002.safetensors",
395
+ "model.layers.9.self_attn.q_a_proj.weight": "model-00001-of-00002.safetensors",
396
+ "model.layers.9.self_attn.q_b_proj.weight": "model-00001-of-00002.safetensors",
397
+ "model.norm.weight": "model-00002-of-00002.safetensors",
398
+ "siglip2.vision_model.embeddings.patch_embedding.bias": "model-00001-of-00002.safetensors",
399
+ "siglip2.vision_model.embeddings.patch_embedding.weight": "model-00001-of-00002.safetensors",
400
+ "siglip2.vision_model.encoder.layers.0.layer_norm1.bias": "model-00001-of-00002.safetensors",
401
+ "siglip2.vision_model.encoder.layers.0.layer_norm1.weight": "model-00001-of-00002.safetensors",
402
+ "siglip2.vision_model.encoder.layers.0.layer_norm2.bias": "model-00001-of-00002.safetensors",
403
+ "siglip2.vision_model.encoder.layers.0.layer_norm2.weight": "model-00001-of-00002.safetensors",
404
+ "siglip2.vision_model.encoder.layers.0.mlp.fc1.bias": "model-00001-of-00002.safetensors",
405
+ "siglip2.vision_model.encoder.layers.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
406
+ "siglip2.vision_model.encoder.layers.0.mlp.fc2.bias": "model-00001-of-00002.safetensors",
407
+ "siglip2.vision_model.encoder.layers.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
408
+ "siglip2.vision_model.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
409
+ "siglip2.vision_model.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
410
+ "siglip2.vision_model.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
411
+ "siglip2.vision_model.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
412
+ "siglip2.vision_model.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
413
+ "siglip2.vision_model.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
414
+ "siglip2.vision_model.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
415
+ "siglip2.vision_model.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
416
+ "siglip2.vision_model.encoder.layers.1.layer_norm1.bias": "model-00001-of-00002.safetensors",
417
+ "siglip2.vision_model.encoder.layers.1.layer_norm1.weight": "model-00001-of-00002.safetensors",
418
+ "siglip2.vision_model.encoder.layers.1.layer_norm2.bias": "model-00001-of-00002.safetensors",
419
+ "siglip2.vision_model.encoder.layers.1.layer_norm2.weight": "model-00001-of-00002.safetensors",
420
+ "siglip2.vision_model.encoder.layers.1.mlp.fc1.bias": "model-00001-of-00002.safetensors",
421
+ "siglip2.vision_model.encoder.layers.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
422
+ "siglip2.vision_model.encoder.layers.1.mlp.fc2.bias": "model-00001-of-00002.safetensors",
423
+ "siglip2.vision_model.encoder.layers.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
424
+ "siglip2.vision_model.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
425
+ "siglip2.vision_model.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
426
+ "siglip2.vision_model.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
427
+ "siglip2.vision_model.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
428
+ "siglip2.vision_model.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
429
+ "siglip2.vision_model.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
430
+ "siglip2.vision_model.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
431
+ "siglip2.vision_model.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
432
+ "siglip2.vision_model.encoder.layers.10.layer_norm1.bias": "model-00001-of-00002.safetensors",
433
+ "siglip2.vision_model.encoder.layers.10.layer_norm1.weight": "model-00001-of-00002.safetensors",
434
+ "siglip2.vision_model.encoder.layers.10.layer_norm2.bias": "model-00001-of-00002.safetensors",
435
+ "siglip2.vision_model.encoder.layers.10.layer_norm2.weight": "model-00001-of-00002.safetensors",
436
+ "siglip2.vision_model.encoder.layers.10.mlp.fc1.bias": "model-00001-of-00002.safetensors",
437
+ "siglip2.vision_model.encoder.layers.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
438
+ "siglip2.vision_model.encoder.layers.10.mlp.fc2.bias": "model-00001-of-00002.safetensors",
439
+ "siglip2.vision_model.encoder.layers.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
440
+ "siglip2.vision_model.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
441
+ "siglip2.vision_model.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
442
+ "siglip2.vision_model.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
443
+ "siglip2.vision_model.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
444
+ "siglip2.vision_model.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
445
+ "siglip2.vision_model.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
446
+ "siglip2.vision_model.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
447
+ "siglip2.vision_model.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
448
+ "siglip2.vision_model.encoder.layers.11.layer_norm1.bias": "model-00001-of-00002.safetensors",
449
+ "siglip2.vision_model.encoder.layers.11.layer_norm1.weight": "model-00001-of-00002.safetensors",
450
+ "siglip2.vision_model.encoder.layers.11.layer_norm2.bias": "model-00001-of-00002.safetensors",
451
+ "siglip2.vision_model.encoder.layers.11.layer_norm2.weight": "model-00001-of-00002.safetensors",
452
+ "siglip2.vision_model.encoder.layers.11.mlp.fc1.bias": "model-00001-of-00002.safetensors",
453
+ "siglip2.vision_model.encoder.layers.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
454
+ "siglip2.vision_model.encoder.layers.11.mlp.fc2.bias": "model-00001-of-00002.safetensors",
455
+ "siglip2.vision_model.encoder.layers.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
456
+ "siglip2.vision_model.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
457
+ "siglip2.vision_model.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
458
+ "siglip2.vision_model.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
459
+ "siglip2.vision_model.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
460
+ "siglip2.vision_model.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
461
+ "siglip2.vision_model.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
462
+ "siglip2.vision_model.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
463
+ "siglip2.vision_model.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
464
+ "siglip2.vision_model.encoder.layers.12.layer_norm1.bias": "model-00001-of-00002.safetensors",
465
+ "siglip2.vision_model.encoder.layers.12.layer_norm1.weight": "model-00001-of-00002.safetensors",
466
+ "siglip2.vision_model.encoder.layers.12.layer_norm2.bias": "model-00001-of-00002.safetensors",
467
+ "siglip2.vision_model.encoder.layers.12.layer_norm2.weight": "model-00001-of-00002.safetensors",
468
+ "siglip2.vision_model.encoder.layers.12.mlp.fc1.bias": "model-00001-of-00002.safetensors",
469
+ "siglip2.vision_model.encoder.layers.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
470
+ "siglip2.vision_model.encoder.layers.12.mlp.fc2.bias": "model-00001-of-00002.safetensors",
471
+ "siglip2.vision_model.encoder.layers.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
472
+ "siglip2.vision_model.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
473
+ "siglip2.vision_model.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
474
+ "siglip2.vision_model.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
475
+ "siglip2.vision_model.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
476
+ "siglip2.vision_model.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
477
+ "siglip2.vision_model.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
478
+ "siglip2.vision_model.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
479
+ "siglip2.vision_model.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
480
+ "siglip2.vision_model.encoder.layers.13.layer_norm1.bias": "model-00001-of-00002.safetensors",
481
+ "siglip2.vision_model.encoder.layers.13.layer_norm1.weight": "model-00001-of-00002.safetensors",
482
+ "siglip2.vision_model.encoder.layers.13.layer_norm2.bias": "model-00001-of-00002.safetensors",
483
+ "siglip2.vision_model.encoder.layers.13.layer_norm2.weight": "model-00001-of-00002.safetensors",
484
+ "siglip2.vision_model.encoder.layers.13.mlp.fc1.bias": "model-00001-of-00002.safetensors",
485
+ "siglip2.vision_model.encoder.layers.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
486
+ "siglip2.vision_model.encoder.layers.13.mlp.fc2.bias": "model-00001-of-00002.safetensors",
487
+ "siglip2.vision_model.encoder.layers.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
488
+ "siglip2.vision_model.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
489
+ "siglip2.vision_model.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
490
+ "siglip2.vision_model.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
491
+ "siglip2.vision_model.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
492
+ "siglip2.vision_model.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
493
+ "siglip2.vision_model.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
494
+ "siglip2.vision_model.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
495
+ "siglip2.vision_model.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
496
+ "siglip2.vision_model.encoder.layers.14.layer_norm1.bias": "model-00001-of-00002.safetensors",
497
+ "siglip2.vision_model.encoder.layers.14.layer_norm1.weight": "model-00001-of-00002.safetensors",
498
+ "siglip2.vision_model.encoder.layers.14.layer_norm2.bias": "model-00001-of-00002.safetensors",
499
+ "siglip2.vision_model.encoder.layers.14.layer_norm2.weight": "model-00001-of-00002.safetensors",
500
+ "siglip2.vision_model.encoder.layers.14.mlp.fc1.bias": "model-00001-of-00002.safetensors",
501
+ "siglip2.vision_model.encoder.layers.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
502
+ "siglip2.vision_model.encoder.layers.14.mlp.fc2.bias": "model-00001-of-00002.safetensors",
503
+ "siglip2.vision_model.encoder.layers.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
504
+ "siglip2.vision_model.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
505
+ "siglip2.vision_model.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
506
+ "siglip2.vision_model.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
507
+ "siglip2.vision_model.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
508
+ "siglip2.vision_model.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
509
+ "siglip2.vision_model.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
510
+ "siglip2.vision_model.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
511
+ "siglip2.vision_model.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
512
+ "siglip2.vision_model.encoder.layers.15.layer_norm1.bias": "model-00001-of-00002.safetensors",
513
+ "siglip2.vision_model.encoder.layers.15.layer_norm1.weight": "model-00001-of-00002.safetensors",
514
+ "siglip2.vision_model.encoder.layers.15.layer_norm2.bias": "model-00001-of-00002.safetensors",
515
+ "siglip2.vision_model.encoder.layers.15.layer_norm2.weight": "model-00001-of-00002.safetensors",
516
+ "siglip2.vision_model.encoder.layers.15.mlp.fc1.bias": "model-00001-of-00002.safetensors",
517
+ "siglip2.vision_model.encoder.layers.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
518
+ "siglip2.vision_model.encoder.layers.15.mlp.fc2.bias": "model-00001-of-00002.safetensors",
519
+ "siglip2.vision_model.encoder.layers.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
520
+ "siglip2.vision_model.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
521
+ "siglip2.vision_model.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
522
+ "siglip2.vision_model.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
523
+ "siglip2.vision_model.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
524
+ "siglip2.vision_model.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
525
+ "siglip2.vision_model.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
526
+ "siglip2.vision_model.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
527
+ "siglip2.vision_model.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
528
+ "siglip2.vision_model.encoder.layers.16.layer_norm1.bias": "model-00001-of-00002.safetensors",
529
+ "siglip2.vision_model.encoder.layers.16.layer_norm1.weight": "model-00001-of-00002.safetensors",
530
+ "siglip2.vision_model.encoder.layers.16.layer_norm2.bias": "model-00001-of-00002.safetensors",
531
+ "siglip2.vision_model.encoder.layers.16.layer_norm2.weight": "model-00001-of-00002.safetensors",
532
+ "siglip2.vision_model.encoder.layers.16.mlp.fc1.bias": "model-00001-of-00002.safetensors",
533
+ "siglip2.vision_model.encoder.layers.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
534
+ "siglip2.vision_model.encoder.layers.16.mlp.fc2.bias": "model-00001-of-00002.safetensors",
535
+ "siglip2.vision_model.encoder.layers.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
536
+ "siglip2.vision_model.encoder.layers.16.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
537
+ "siglip2.vision_model.encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
538
+ "siglip2.vision_model.encoder.layers.16.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
539
+ "siglip2.vision_model.encoder.layers.16.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
540
+ "siglip2.vision_model.encoder.layers.16.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
541
+ "siglip2.vision_model.encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
542
+ "siglip2.vision_model.encoder.layers.16.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
543
+ "siglip2.vision_model.encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
544
+ "siglip2.vision_model.encoder.layers.17.layer_norm1.bias": "model-00001-of-00002.safetensors",
545
+ "siglip2.vision_model.encoder.layers.17.layer_norm1.weight": "model-00001-of-00002.safetensors",
546
+ "siglip2.vision_model.encoder.layers.17.layer_norm2.bias": "model-00001-of-00002.safetensors",
547
+ "siglip2.vision_model.encoder.layers.17.layer_norm2.weight": "model-00001-of-00002.safetensors",
548
+ "siglip2.vision_model.encoder.layers.17.mlp.fc1.bias": "model-00001-of-00002.safetensors",
549
+ "siglip2.vision_model.encoder.layers.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
550
+ "siglip2.vision_model.encoder.layers.17.mlp.fc2.bias": "model-00001-of-00002.safetensors",
551
+ "siglip2.vision_model.encoder.layers.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
552
+ "siglip2.vision_model.encoder.layers.17.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
553
+ "siglip2.vision_model.encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
554
+ "siglip2.vision_model.encoder.layers.17.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
555
+ "siglip2.vision_model.encoder.layers.17.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
556
+ "siglip2.vision_model.encoder.layers.17.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
557
+ "siglip2.vision_model.encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
558
+ "siglip2.vision_model.encoder.layers.17.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
559
+ "siglip2.vision_model.encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
560
+ "siglip2.vision_model.encoder.layers.18.layer_norm1.bias": "model-00001-of-00002.safetensors",
561
+ "siglip2.vision_model.encoder.layers.18.layer_norm1.weight": "model-00001-of-00002.safetensors",
562
+ "siglip2.vision_model.encoder.layers.18.layer_norm2.bias": "model-00001-of-00002.safetensors",
563
+ "siglip2.vision_model.encoder.layers.18.layer_norm2.weight": "model-00001-of-00002.safetensors",
564
+ "siglip2.vision_model.encoder.layers.18.mlp.fc1.bias": "model-00001-of-00002.safetensors",
565
+ "siglip2.vision_model.encoder.layers.18.mlp.fc1.weight": "model-00001-of-00002.safetensors",
566
+ "siglip2.vision_model.encoder.layers.18.mlp.fc2.bias": "model-00001-of-00002.safetensors",
567
+ "siglip2.vision_model.encoder.layers.18.mlp.fc2.weight": "model-00001-of-00002.safetensors",
568
+ "siglip2.vision_model.encoder.layers.18.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
569
+ "siglip2.vision_model.encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
570
+ "siglip2.vision_model.encoder.layers.18.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
571
+ "siglip2.vision_model.encoder.layers.18.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
572
+ "siglip2.vision_model.encoder.layers.18.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
573
+ "siglip2.vision_model.encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
574
+ "siglip2.vision_model.encoder.layers.18.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
575
+ "siglip2.vision_model.encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
576
+ "siglip2.vision_model.encoder.layers.19.layer_norm1.bias": "model-00001-of-00002.safetensors",
577
+ "siglip2.vision_model.encoder.layers.19.layer_norm1.weight": "model-00001-of-00002.safetensors",
578
+ "siglip2.vision_model.encoder.layers.19.layer_norm2.bias": "model-00001-of-00002.safetensors",
579
+ "siglip2.vision_model.encoder.layers.19.layer_norm2.weight": "model-00001-of-00002.safetensors",
580
+ "siglip2.vision_model.encoder.layers.19.mlp.fc1.bias": "model-00001-of-00002.safetensors",
581
+ "siglip2.vision_model.encoder.layers.19.mlp.fc1.weight": "model-00001-of-00002.safetensors",
582
+ "siglip2.vision_model.encoder.layers.19.mlp.fc2.bias": "model-00001-of-00002.safetensors",
583
+ "siglip2.vision_model.encoder.layers.19.mlp.fc2.weight": "model-00001-of-00002.safetensors",
584
+ "siglip2.vision_model.encoder.layers.19.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
585
+ "siglip2.vision_model.encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
586
+ "siglip2.vision_model.encoder.layers.19.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
587
+ "siglip2.vision_model.encoder.layers.19.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
588
+ "siglip2.vision_model.encoder.layers.19.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
589
+ "siglip2.vision_model.encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
590
+ "siglip2.vision_model.encoder.layers.19.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
591
+ "siglip2.vision_model.encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
592
+ "siglip2.vision_model.encoder.layers.2.layer_norm1.bias": "model-00001-of-00002.safetensors",
593
+ "siglip2.vision_model.encoder.layers.2.layer_norm1.weight": "model-00001-of-00002.safetensors",
594
+ "siglip2.vision_model.encoder.layers.2.layer_norm2.bias": "model-00001-of-00002.safetensors",
595
+ "siglip2.vision_model.encoder.layers.2.layer_norm2.weight": "model-00001-of-00002.safetensors",
596
+ "siglip2.vision_model.encoder.layers.2.mlp.fc1.bias": "model-00001-of-00002.safetensors",
597
+ "siglip2.vision_model.encoder.layers.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
598
+ "siglip2.vision_model.encoder.layers.2.mlp.fc2.bias": "model-00001-of-00002.safetensors",
599
+ "siglip2.vision_model.encoder.layers.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
600
+ "siglip2.vision_model.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
601
+ "siglip2.vision_model.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
602
+ "siglip2.vision_model.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
603
+ "siglip2.vision_model.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
604
+ "siglip2.vision_model.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
605
+ "siglip2.vision_model.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
606
+ "siglip2.vision_model.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
607
+ "siglip2.vision_model.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
608
+ "siglip2.vision_model.encoder.layers.20.layer_norm1.bias": "model-00001-of-00002.safetensors",
609
+ "siglip2.vision_model.encoder.layers.20.layer_norm1.weight": "model-00001-of-00002.safetensors",
610
+ "siglip2.vision_model.encoder.layers.20.layer_norm2.bias": "model-00001-of-00002.safetensors",
611
+ "siglip2.vision_model.encoder.layers.20.layer_norm2.weight": "model-00001-of-00002.safetensors",
612
+ "siglip2.vision_model.encoder.layers.20.mlp.fc1.bias": "model-00001-of-00002.safetensors",
613
+ "siglip2.vision_model.encoder.layers.20.mlp.fc1.weight": "model-00001-of-00002.safetensors",
614
+ "siglip2.vision_model.encoder.layers.20.mlp.fc2.bias": "model-00001-of-00002.safetensors",
615
+ "siglip2.vision_model.encoder.layers.20.mlp.fc2.weight": "model-00001-of-00002.safetensors",
616
+ "siglip2.vision_model.encoder.layers.20.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
617
+ "siglip2.vision_model.encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
618
+ "siglip2.vision_model.encoder.layers.20.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
619
+ "siglip2.vision_model.encoder.layers.20.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
620
+ "siglip2.vision_model.encoder.layers.20.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
621
+ "siglip2.vision_model.encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
622
+ "siglip2.vision_model.encoder.layers.20.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
623
+ "siglip2.vision_model.encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
624
+ "siglip2.vision_model.encoder.layers.21.layer_norm1.bias": "model-00001-of-00002.safetensors",
625
+ "siglip2.vision_model.encoder.layers.21.layer_norm1.weight": "model-00001-of-00002.safetensors",
626
+ "siglip2.vision_model.encoder.layers.21.layer_norm2.bias": "model-00001-of-00002.safetensors",
627
+ "siglip2.vision_model.encoder.layers.21.layer_norm2.weight": "model-00001-of-00002.safetensors",
628
+ "siglip2.vision_model.encoder.layers.21.mlp.fc1.bias": "model-00001-of-00002.safetensors",
629
+ "siglip2.vision_model.encoder.layers.21.mlp.fc1.weight": "model-00001-of-00002.safetensors",
630
+ "siglip2.vision_model.encoder.layers.21.mlp.fc2.bias": "model-00001-of-00002.safetensors",
631
+ "siglip2.vision_model.encoder.layers.21.mlp.fc2.weight": "model-00001-of-00002.safetensors",
632
+ "siglip2.vision_model.encoder.layers.21.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
633
+ "siglip2.vision_model.encoder.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
634
+ "siglip2.vision_model.encoder.layers.21.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
635
+ "siglip2.vision_model.encoder.layers.21.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
636
+ "siglip2.vision_model.encoder.layers.21.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
637
+ "siglip2.vision_model.encoder.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
638
+ "siglip2.vision_model.encoder.layers.21.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
639
+ "siglip2.vision_model.encoder.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
640
+ "siglip2.vision_model.encoder.layers.22.layer_norm1.bias": "model-00001-of-00002.safetensors",
641
+ "siglip2.vision_model.encoder.layers.22.layer_norm1.weight": "model-00001-of-00002.safetensors",
642
+ "siglip2.vision_model.encoder.layers.22.layer_norm2.bias": "model-00001-of-00002.safetensors",
643
+ "siglip2.vision_model.encoder.layers.22.layer_norm2.weight": "model-00001-of-00002.safetensors",
644
+ "siglip2.vision_model.encoder.layers.22.mlp.fc1.bias": "model-00001-of-00002.safetensors",
645
+ "siglip2.vision_model.encoder.layers.22.mlp.fc1.weight": "model-00001-of-00002.safetensors",
646
+ "siglip2.vision_model.encoder.layers.22.mlp.fc2.bias": "model-00001-of-00002.safetensors",
647
+ "siglip2.vision_model.encoder.layers.22.mlp.fc2.weight": "model-00001-of-00002.safetensors",
648
+ "siglip2.vision_model.encoder.layers.22.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
649
+ "siglip2.vision_model.encoder.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
650
+ "siglip2.vision_model.encoder.layers.22.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
651
+ "siglip2.vision_model.encoder.layers.22.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
652
+ "siglip2.vision_model.encoder.layers.22.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
653
+ "siglip2.vision_model.encoder.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
654
+ "siglip2.vision_model.encoder.layers.22.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
655
+ "siglip2.vision_model.encoder.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
656
+ "siglip2.vision_model.encoder.layers.23.layer_norm1.bias": "model-00001-of-00002.safetensors",
657
+ "siglip2.vision_model.encoder.layers.23.layer_norm1.weight": "model-00001-of-00002.safetensors",
658
+ "siglip2.vision_model.encoder.layers.23.layer_norm2.bias": "model-00001-of-00002.safetensors",
659
+ "siglip2.vision_model.encoder.layers.23.layer_norm2.weight": "model-00001-of-00002.safetensors",
660
+ "siglip2.vision_model.encoder.layers.23.mlp.fc1.bias": "model-00001-of-00002.safetensors",
661
+ "siglip2.vision_model.encoder.layers.23.mlp.fc1.weight": "model-00001-of-00002.safetensors",
662
+ "siglip2.vision_model.encoder.layers.23.mlp.fc2.bias": "model-00001-of-00002.safetensors",
663
+ "siglip2.vision_model.encoder.layers.23.mlp.fc2.weight": "model-00001-of-00002.safetensors",
664
+ "siglip2.vision_model.encoder.layers.23.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
665
+ "siglip2.vision_model.encoder.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
666
+ "siglip2.vision_model.encoder.layers.23.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
667
+ "siglip2.vision_model.encoder.layers.23.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
668
+ "siglip2.vision_model.encoder.layers.23.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
669
+ "siglip2.vision_model.encoder.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
670
+ "siglip2.vision_model.encoder.layers.23.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
671
+ "siglip2.vision_model.encoder.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
672
+ "siglip2.vision_model.encoder.layers.24.layer_norm1.bias": "model-00001-of-00002.safetensors",
673
+ "siglip2.vision_model.encoder.layers.24.layer_norm1.weight": "model-00001-of-00002.safetensors",
674
+ "siglip2.vision_model.encoder.layers.24.layer_norm2.bias": "model-00001-of-00002.safetensors",
675
+ "siglip2.vision_model.encoder.layers.24.layer_norm2.weight": "model-00001-of-00002.safetensors",
676
+ "siglip2.vision_model.encoder.layers.24.mlp.fc1.bias": "model-00001-of-00002.safetensors",
677
+ "siglip2.vision_model.encoder.layers.24.mlp.fc1.weight": "model-00001-of-00002.safetensors",
678
+ "siglip2.vision_model.encoder.layers.24.mlp.fc2.bias": "model-00001-of-00002.safetensors",
679
+ "siglip2.vision_model.encoder.layers.24.mlp.fc2.weight": "model-00001-of-00002.safetensors",
680
+ "siglip2.vision_model.encoder.layers.24.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
681
+ "siglip2.vision_model.encoder.layers.24.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
682
+ "siglip2.vision_model.encoder.layers.24.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
683
+ "siglip2.vision_model.encoder.layers.24.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
684
+ "siglip2.vision_model.encoder.layers.24.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
685
+ "siglip2.vision_model.encoder.layers.24.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
686
+ "siglip2.vision_model.encoder.layers.24.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
687
+ "siglip2.vision_model.encoder.layers.24.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
688
+ "siglip2.vision_model.encoder.layers.25.layer_norm1.bias": "model-00001-of-00002.safetensors",
689
+ "siglip2.vision_model.encoder.layers.25.layer_norm1.weight": "model-00001-of-00002.safetensors",
690
+ "siglip2.vision_model.encoder.layers.25.layer_norm2.bias": "model-00001-of-00002.safetensors",
691
+ "siglip2.vision_model.encoder.layers.25.layer_norm2.weight": "model-00001-of-00002.safetensors",
692
+ "siglip2.vision_model.encoder.layers.25.mlp.fc1.bias": "model-00001-of-00002.safetensors",
693
+ "siglip2.vision_model.encoder.layers.25.mlp.fc1.weight": "model-00001-of-00002.safetensors",
694
+ "siglip2.vision_model.encoder.layers.25.mlp.fc2.bias": "model-00001-of-00002.safetensors",
695
+ "siglip2.vision_model.encoder.layers.25.mlp.fc2.weight": "model-00001-of-00002.safetensors",
696
+ "siglip2.vision_model.encoder.layers.25.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
697
+ "siglip2.vision_model.encoder.layers.25.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
698
+ "siglip2.vision_model.encoder.layers.25.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
699
+ "siglip2.vision_model.encoder.layers.25.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
700
+ "siglip2.vision_model.encoder.layers.25.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
701
+ "siglip2.vision_model.encoder.layers.25.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
702
+ "siglip2.vision_model.encoder.layers.25.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
703
+ "siglip2.vision_model.encoder.layers.25.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
704
+ "siglip2.vision_model.encoder.layers.26.layer_norm1.bias": "model-00001-of-00002.safetensors",
705
+ "siglip2.vision_model.encoder.layers.26.layer_norm1.weight": "model-00001-of-00002.safetensors",
706
+ "siglip2.vision_model.encoder.layers.26.layer_norm2.bias": "model-00001-of-00002.safetensors",
707
+ "siglip2.vision_model.encoder.layers.26.layer_norm2.weight": "model-00001-of-00002.safetensors",
708
+ "siglip2.vision_model.encoder.layers.26.mlp.fc1.bias": "model-00001-of-00002.safetensors",
709
+ "siglip2.vision_model.encoder.layers.26.mlp.fc1.weight": "model-00001-of-00002.safetensors",
710
+ "siglip2.vision_model.encoder.layers.26.mlp.fc2.bias": "model-00001-of-00002.safetensors",
711
+ "siglip2.vision_model.encoder.layers.26.mlp.fc2.weight": "model-00001-of-00002.safetensors",
712
+ "siglip2.vision_model.encoder.layers.26.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
713
+ "siglip2.vision_model.encoder.layers.26.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
714
+ "siglip2.vision_model.encoder.layers.26.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
715
+ "siglip2.vision_model.encoder.layers.26.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
716
+ "siglip2.vision_model.encoder.layers.26.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
717
+ "siglip2.vision_model.encoder.layers.26.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
718
+ "siglip2.vision_model.encoder.layers.26.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
719
+ "siglip2.vision_model.encoder.layers.26.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
720
+ "siglip2.vision_model.encoder.layers.3.layer_norm1.bias": "model-00001-of-00002.safetensors",
721
+ "siglip2.vision_model.encoder.layers.3.layer_norm1.weight": "model-00001-of-00002.safetensors",
722
+ "siglip2.vision_model.encoder.layers.3.layer_norm2.bias": "model-00001-of-00002.safetensors",
723
+ "siglip2.vision_model.encoder.layers.3.layer_norm2.weight": "model-00001-of-00002.safetensors",
724
+ "siglip2.vision_model.encoder.layers.3.mlp.fc1.bias": "model-00001-of-00002.safetensors",
725
+ "siglip2.vision_model.encoder.layers.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
726
+ "siglip2.vision_model.encoder.layers.3.mlp.fc2.bias": "model-00001-of-00002.safetensors",
727
+ "siglip2.vision_model.encoder.layers.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
728
+ "siglip2.vision_model.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
729
+ "siglip2.vision_model.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
730
+ "siglip2.vision_model.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
731
+ "siglip2.vision_model.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
732
+ "siglip2.vision_model.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
733
+ "siglip2.vision_model.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
734
+ "siglip2.vision_model.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
735
+ "siglip2.vision_model.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
736
+ "siglip2.vision_model.encoder.layers.4.layer_norm1.bias": "model-00001-of-00002.safetensors",
737
+ "siglip2.vision_model.encoder.layers.4.layer_norm1.weight": "model-00001-of-00002.safetensors",
738
+ "siglip2.vision_model.encoder.layers.4.layer_norm2.bias": "model-00001-of-00002.safetensors",
739
+ "siglip2.vision_model.encoder.layers.4.layer_norm2.weight": "model-00001-of-00002.safetensors",
740
+ "siglip2.vision_model.encoder.layers.4.mlp.fc1.bias": "model-00001-of-00002.safetensors",
741
+ "siglip2.vision_model.encoder.layers.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
742
+ "siglip2.vision_model.encoder.layers.4.mlp.fc2.bias": "model-00001-of-00002.safetensors",
743
+ "siglip2.vision_model.encoder.layers.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
744
+ "siglip2.vision_model.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
745
+ "siglip2.vision_model.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
746
+ "siglip2.vision_model.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
747
+ "siglip2.vision_model.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
748
+ "siglip2.vision_model.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
749
+ "siglip2.vision_model.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
750
+ "siglip2.vision_model.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
751
+ "siglip2.vision_model.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
752
+ "siglip2.vision_model.encoder.layers.5.layer_norm1.bias": "model-00001-of-00002.safetensors",
753
+ "siglip2.vision_model.encoder.layers.5.layer_norm1.weight": "model-00001-of-00002.safetensors",
754
+ "siglip2.vision_model.encoder.layers.5.layer_norm2.bias": "model-00001-of-00002.safetensors",
755
+ "siglip2.vision_model.encoder.layers.5.layer_norm2.weight": "model-00001-of-00002.safetensors",
756
+ "siglip2.vision_model.encoder.layers.5.mlp.fc1.bias": "model-00001-of-00002.safetensors",
757
+ "siglip2.vision_model.encoder.layers.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
758
+ "siglip2.vision_model.encoder.layers.5.mlp.fc2.bias": "model-00001-of-00002.safetensors",
759
+ "siglip2.vision_model.encoder.layers.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
760
+ "siglip2.vision_model.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
761
+ "siglip2.vision_model.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
762
+ "siglip2.vision_model.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
763
+ "siglip2.vision_model.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
764
+ "siglip2.vision_model.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
765
+ "siglip2.vision_model.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
766
+ "siglip2.vision_model.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
767
+ "siglip2.vision_model.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
768
+ "siglip2.vision_model.encoder.layers.6.layer_norm1.bias": "model-00001-of-00002.safetensors",
769
+ "siglip2.vision_model.encoder.layers.6.layer_norm1.weight": "model-00001-of-00002.safetensors",
770
+ "siglip2.vision_model.encoder.layers.6.layer_norm2.bias": "model-00001-of-00002.safetensors",
771
+ "siglip2.vision_model.encoder.layers.6.layer_norm2.weight": "model-00001-of-00002.safetensors",
772
+ "siglip2.vision_model.encoder.layers.6.mlp.fc1.bias": "model-00001-of-00002.safetensors",
773
+ "siglip2.vision_model.encoder.layers.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
774
+ "siglip2.vision_model.encoder.layers.6.mlp.fc2.bias": "model-00001-of-00002.safetensors",
775
+ "siglip2.vision_model.encoder.layers.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
776
+ "siglip2.vision_model.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
777
+ "siglip2.vision_model.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
778
+ "siglip2.vision_model.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
779
+ "siglip2.vision_model.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
780
+ "siglip2.vision_model.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
781
+ "siglip2.vision_model.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
782
+ "siglip2.vision_model.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
783
+ "siglip2.vision_model.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
784
+ "siglip2.vision_model.encoder.layers.7.layer_norm1.bias": "model-00001-of-00002.safetensors",
785
+ "siglip2.vision_model.encoder.layers.7.layer_norm1.weight": "model-00001-of-00002.safetensors",
786
+ "siglip2.vision_model.encoder.layers.7.layer_norm2.bias": "model-00001-of-00002.safetensors",
787
+ "siglip2.vision_model.encoder.layers.7.layer_norm2.weight": "model-00001-of-00002.safetensors",
788
+ "siglip2.vision_model.encoder.layers.7.mlp.fc1.bias": "model-00001-of-00002.safetensors",
789
+ "siglip2.vision_model.encoder.layers.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
790
+ "siglip2.vision_model.encoder.layers.7.mlp.fc2.bias": "model-00001-of-00002.safetensors",
791
+ "siglip2.vision_model.encoder.layers.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
792
+ "siglip2.vision_model.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
793
+ "siglip2.vision_model.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
794
+ "siglip2.vision_model.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
795
+ "siglip2.vision_model.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
796
+ "siglip2.vision_model.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
797
+ "siglip2.vision_model.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
798
+ "siglip2.vision_model.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
799
+ "siglip2.vision_model.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
800
+ "siglip2.vision_model.encoder.layers.8.layer_norm1.bias": "model-00001-of-00002.safetensors",
801
+ "siglip2.vision_model.encoder.layers.8.layer_norm1.weight": "model-00001-of-00002.safetensors",
802
+ "siglip2.vision_model.encoder.layers.8.layer_norm2.bias": "model-00001-of-00002.safetensors",
803
+ "siglip2.vision_model.encoder.layers.8.layer_norm2.weight": "model-00001-of-00002.safetensors",
804
+ "siglip2.vision_model.encoder.layers.8.mlp.fc1.bias": "model-00001-of-00002.safetensors",
805
+ "siglip2.vision_model.encoder.layers.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
806
+ "siglip2.vision_model.encoder.layers.8.mlp.fc2.bias": "model-00001-of-00002.safetensors",
807
+ "siglip2.vision_model.encoder.layers.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
808
+ "siglip2.vision_model.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
809
+ "siglip2.vision_model.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
810
+ "siglip2.vision_model.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
811
+ "siglip2.vision_model.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
812
+ "siglip2.vision_model.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
813
+ "siglip2.vision_model.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
814
+ "siglip2.vision_model.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
815
+ "siglip2.vision_model.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
816
+ "siglip2.vision_model.encoder.layers.9.layer_norm1.bias": "model-00001-of-00002.safetensors",
817
+ "siglip2.vision_model.encoder.layers.9.layer_norm1.weight": "model-00001-of-00002.safetensors",
818
+ "siglip2.vision_model.encoder.layers.9.layer_norm2.bias": "model-00001-of-00002.safetensors",
819
+ "siglip2.vision_model.encoder.layers.9.layer_norm2.weight": "model-00001-of-00002.safetensors",
820
+ "siglip2.vision_model.encoder.layers.9.mlp.fc1.bias": "model-00001-of-00002.safetensors",
821
+ "siglip2.vision_model.encoder.layers.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
822
+ "siglip2.vision_model.encoder.layers.9.mlp.fc2.bias": "model-00001-of-00002.safetensors",
823
+ "siglip2.vision_model.encoder.layers.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
824
+ "siglip2.vision_model.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
825
+ "siglip2.vision_model.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
826
+ "siglip2.vision_model.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
827
+ "siglip2.vision_model.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
828
+ "siglip2.vision_model.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
829
+ "siglip2.vision_model.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
830
+ "siglip2.vision_model.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
831
+ "siglip2.vision_model.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
832
+ "siglip2.vision_model.post_layernorm.bias": "model-00001-of-00002.safetensors",
833
+ "siglip2.vision_model.post_layernorm.weight": "model-00001-of-00002.safetensors"
834
+ }
835
+ }
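The weight map above assigns every parameter name to the shard file that stores it, so a loader only has to open the shard(s) it actually needs. Below is a minimal, illustrative sketch (not part of this repository) of resolving a single tensor through such an index, assuming the standard safetensors Python API and the usual model.safetensors.index.json file name:

    import json
    from safetensors import safe_open

    # Load the index and look up which shard holds a given parameter.
    with open("model.safetensors.index.json") as f:
        index = json.load(f)

    name = "model.norm.weight"                # listed above under weight_map
    shard = index["weight_map"][name]         # e.g. "model-00002-of-00002.safetensors"

    # Open only that shard and read the single tensor.
    with safe_open(shard, framework="pt", device="cpu") as shard_file:
        tensor = shard_file.get_tensor(name)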
modeling_siglip2.py ADDED
@@ -0,0 +1,1594 @@
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Optional, Tuple, Union
5
+
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
12
+ from torch.nn.init import _calculate_fan_in_and_fan_out
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
16
+ # from transformers.modeling_layers import GradientCheckpointingLayer
17
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
18
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
19
+ from transformers.utils import (
20
+ ModelOutput,
21
+ add_start_docstrings,
22
+ add_start_docstrings_to_model_forward,
23
+ can_return_tuple,
24
+ logging,
25
+ replace_return_docstrings,
26
+ is_flash_attn_2_available,
27
+ is_flash_attn_greater_or_equal_2_10,
28
+ )
29
+
30
+ from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ _CONFIG_FOR_DOC = "Siglip2Config"
36
+
37
+ is_aiter_available = False
38
+
39
+ if is_flash_attn_2_available():
40
+ try:
41
+ from aiter import flash_attn_varlen_func
42
+ is_aiter_available = True
43
+ except ImportError:
44
+ from flash_attn import flash_attn_varlen_func
45
+ from flash_attn.layers.rotary import apply_rotary_emb
46
+
47
+ else:
48
+ flash_attn_varlen_func = None
49
+ apply_rotary_emb = None
50
+
51
+
52
+ if is_flash_attn_2_available():
53
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
54
+ else:
55
+ flash_attn_varlen_func = None
56
+
57
+ @dataclass
58
+ class Siglip2VisionOutput(ModelOutput):
59
+ """
60
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
61
+
62
+ Args:
63
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
64
+ *optional* returned when model is initialized with `with_projection=True`):
65
+ The image embeddings obtained by applying the projection layer to the pooler_output.
66
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
67
+ Sequence of hidden-states at the output of the last layer of the model.
68
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`
69
+ is passed or when `config.output_hidden_states=True`):
70
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
71
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
72
+
73
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
74
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or
75
+ when `config.output_attentions=True`):
76
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
77
+ sequence_length)`.
78
+
79
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
80
+ heads.
81
+ """
82
+
83
+ image_embeds: Optional[torch.FloatTensor] = None
84
+ last_hidden_state: Optional[torch.FloatTensor] = None
85
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
86
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
87
+
88
+
89
+ @dataclass
90
+ class Siglip2TextOutput(ModelOutput):
91
+ """
92
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
93
+
94
+ Args:
95
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
96
+ *optional* returned when model is initialized with `with_projection=True`):
97
+ The text embeddings obtained by applying the projection layer to the pooler_output.
98
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
99
+ Sequence of hidden-states at the output of the last layer of the model.
100
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned
101
+ when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
102
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
103
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
104
+
105
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
106
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
107
+ or when `config.output_attentions=True`):
108
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
109
+ sequence_length)`.
110
+
111
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
112
+ heads.
113
+ """
114
+
115
+ text_embeds: Optional[torch.FloatTensor] = None
116
+ last_hidden_state: Optional[torch.FloatTensor] = None
117
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
118
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
119
+
120
+
121
+ @dataclass
122
+ class Siglip2Output(ModelOutput):
123
+ """
124
+ Args:
125
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
126
+ Contrastive loss for image-text similarity.
127
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
128
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
129
+ similarity scores.
130
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
131
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
132
+ similarity scores.
133
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
134
+ The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
135
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
136
+ The image embeddings obtained by applying the projection layer to
137
+ the pooled output of [`Siglip2VisionModel`].
138
+ text_model_output (`BaseModelOutputWithPooling`):
139
+ The output of the [`Siglip2TextModel`].
140
+ vision_model_output (`BaseModelOutputWithPooling`):
141
+ The output of the [`Siglip2VisionModel`].
142
+ """
143
+
144
+ loss: Optional[torch.FloatTensor] = None
145
+ logits_per_image: Optional[torch.FloatTensor] = None
146
+ logits_per_text: Optional[torch.FloatTensor] = None
147
+ text_embeds: Optional[torch.FloatTensor] = None
148
+ image_embeds: Optional[torch.FloatTensor] = None
149
+ text_model_output: BaseModelOutputWithPooling = None
150
+ vision_model_output: BaseModelOutputWithPooling = None
151
+
152
+ def to_tuple(self) -> Tuple[Any]:
153
+ return tuple(
154
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
155
+ for k in self.keys()
156
+ )
157
+
158
+ class Siglip2VisionEmbeddings(nn.Module):
159
+ def __init__(self, config: Siglip2VisionConfig):
160
+ super().__init__()
161
+ self.config = config
162
+ self.embed_dim = config.hidden_size
163
+ self.patch_size = config.patch_size
164
+
165
+ if hasattr(config, 'in_features') and config.in_features > 0:
166
+ self.in_features = config.in_features
167
+ else:
168
+ self.in_features = config.num_channels * self.patch_size * self.patch_size
169
+
170
+ self.patch_embedding = nn.Linear(
171
+ in_features=self.in_features,
172
+ out_features=self.embed_dim,
173
+ )
174
+
175
+ self.num_patches = config.num_patches
176
+ self.position_embedding_size = int(self.num_patches**0.5)
177
+ self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
178
+
179
+ @staticmethod
180
+ def resize_positional_embeddings(
181
+ positional_embeddings: torch.Tensor,
182
+ spatial_shapes: torch.LongTensor,
183
+ max_length: int,
184
+ ) -> torch.Tensor:
185
+ """
186
+ Resize positional embeddings to image-specific size and pad to a fixed size.
187
+
188
+ Args:
189
+ positional_embeddings (`torch.Tensor`):
190
+ Position embeddings of shape (height, width, embed_dim)
191
+ spatial_shapes (`torch.LongTensor`):
192
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
193
+ max_length (`int`):
194
+ Maximum length of the positional embeddings to pad resized positional embeddings to
195
+
196
+ Returns:
197
+ `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
198
+ """
199
+ batch_size = spatial_shapes.shape[0]
200
+ embed_dim = positional_embeddings.shape[-1]
201
+ source_dtype = positional_embeddings.dtype
202
+
203
+ resulted_positional_embeddings = torch.empty(
204
+ (batch_size, max_length, embed_dim),
205
+ device=positional_embeddings.device,
206
+ dtype=source_dtype,
207
+ )
208
+
209
+ positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
210
+ if positional_embeddings.device.type == "cpu":
211
+ positional_embeddings = positional_embeddings.to(torch.float32)
212
+ for i in range(batch_size):
213
+ height, width = spatial_shapes[i]
214
+ resized_embeddings = F.interpolate(
215
+ positional_embeddings,
216
+ size=(height, width),
217
+ mode="bilinear",
218
+ align_corners=False,
219
+ antialias=True,
220
+ )
221
+ resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
222
+ resized_embeddings = resized_embeddings.to(source_dtype)
223
+
224
+ resulted_positional_embeddings[i, : height * width] = resized_embeddings
225
+ resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
226
+
227
+ return resulted_positional_embeddings
228
+
229
+
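# Illustrative note (comments below are not part of the upstream file): the static
# helper above bilinearly interpolates the learned square position grid to each
# image's (height, width) patch grid, flattens it, and pads every sample to
# max_length. A hypothetical call could look like:
#   pos = position_embedding.weight.reshape(16, 16, -1)      # 16x16 learned grid (example size)
#   shapes = torch.tensor([[8, 12], [4, 4]])                  # per-image (h, w) in patches
#   out = Siglip2VisionEmbeddings.resize_positional_embeddings(pos, shapes, max_length=128)
#   # out has shape (2, 128, embed_dim); rows past h*w repeat the first resized row as padding.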
230
+ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
231
+ """
232
+ Args:
233
+ pixel_values (`torch.FloatTensor`):
234
+ Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
235
+ spatial_shapes (`List[Tuple[int, int]]`):
236
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
237
+ """
238
+
239
+ target_dtype = self.patch_embedding.weight.dtype
240
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
241
+ positional_embeddings = self.position_embedding.weight.reshape(
242
+ self.position_embedding_size, self.position_embedding_size, -1
243
+ )
244
+
245
+ resized_positional_embeddings = self.resize_positional_embeddings(
246
+ positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
247
+ )
248
+ embeddings = patch_embeds + resized_positional_embeddings
249
+ return embeddings
250
+
251
+
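A minimal sketch of the resizing helper above in isolation, assuming this file is importable as `modeling_siglip2`; the 16x16 pretrained grid, the embed dim, and the two patch-grid shapes are invented for illustration:

import torch

from modeling_siglip2 import Siglip2VisionEmbeddings  # assumed import path for this repo

# Hypothetical pretrained position grid: 16x16 positions, embed_dim = 64.
grid = torch.randn(16, 16, 64)
# Two images whose preprocessed patch grids are 8x12 and 10x10.
spatial_shapes = torch.tensor([[8, 12], [10, 10]])

resized = Siglip2VisionEmbeddings.resize_positional_embeddings(
    grid, spatial_shapes, max_length=120
)
print(resized.shape)  # torch.Size([2, 120, 64]); rows beyond h*w are padding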
252
+ class Siglip2VisionEmbeddingsWoPos(nn.Module):
253
+ def __init__(self, config: Siglip2VisionConfig):
254
+ super().__init__()
255
+ self.config = config
256
+ self.embed_dim = config.hidden_size
257
+ self.patch_size = config.patch_size
258
+
259
+ if hasattr(config, 'in_features') and config.in_features > 0:
260
+ self.in_features = config.in_features
261
+ else:
262
+ self.in_features = config.num_channels * self.patch_size * self.patch_size
263
+
264
+ self.patch_embedding = nn.Linear(
265
+ in_features=self.in_features,
266
+ out_features=self.embed_dim,
267
+ )
268
+
269
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
270
+ target_dtype = self.patch_embedding.weight.dtype
271
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
272
+ patch_embeds = patch_embeds.view(-1, self.embed_dim)
273
+ return patch_embeds
274
+
275
+
276
+ def eager_attention_forward(
277
+ module: nn.Module,
278
+ query: torch.Tensor,
279
+ key: torch.Tensor,
280
+ value: torch.Tensor,
281
+ attention_mask: Optional[torch.Tensor],
282
+ scaling: float,
283
+ dropout: float = 0.0,
284
+ **kwargs,
285
+ ):
286
+
287
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
288
+ if attention_mask is not None:
289
+ attn_weights = attn_weights + attention_mask
290
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
291
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
292
+ attn_output = torch.matmul(attn_weights, value)
293
+ attn_output = attn_output.transpose(1, 2).contiguous()
294
+ return attn_output, attn_weights
295
+
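A quick sanity check of `eager_attention_forward` on toy tensors; the dummy `nn.Module()` and the shapes are illustrative assumptions (the helper only reads `.training` from its first argument):

import torch
from torch import nn

batch, heads, seq, head_dim = 2, 4, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# eager_attention_forward is the module-level helper defined above.
out, weights = eager_attention_forward(
    nn.Module(), q, k, v, attention_mask=None, scaling=head_dim ** -0.5
)
print(out.shape)      # (2, 5, 4, 8): sequence dim moved ahead of heads
print(weights.shape)  # (2, 4, 5, 5)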
296
+ def apply_rotary_pos_emb_flashatt(
297
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ cos = cos.chunk(2, dim=-1)[0].contiguous()
300
+ sin = sin.chunk(2, dim=-1)[0].contiguous()
301
+ q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
302
+ k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
303
+ return q_embed, k_embed
304
+
305
+ class Siglip2Attention(nn.Module):
306
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
307
+
308
+ def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
309
+ super().__init__()
310
+ self.config = config
311
+ self.embed_dim = config.hidden_size
312
+ self.num_heads = config.num_attention_heads
313
+ self.head_dim = self.embed_dim // self.num_heads
314
+ if self.head_dim * self.num_heads != self.embed_dim:
315
+ raise ValueError(
316
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
317
+ f" {self.num_heads})."
318
+ )
319
+ self.scale = self.head_dim**-0.5
320
+ self.dropout = config.attention_dropout
321
+ self.is_causal = False
322
+
323
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
324
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
325
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
326
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ attention_mask: Optional[torch.Tensor] = None,
332
+ output_attentions: Optional[bool] = False,
333
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
334
+ """Input shape: Batch x Time x Channel"""
335
+
336
+ batch_size, seq_length, embed_dim = hidden_states.shape
337
+
338
+ queries = self.q_proj(hidden_states)
339
+ keys = self.k_proj(hidden_states)
340
+ values = self.v_proj(hidden_states)
341
+
342
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
343
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
344
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
345
+
346
+ attention_interface: Callable = eager_attention_forward
347
+ if self.config._attn_implementation != "eager":
348
+ if self.config._attn_implementation == "sdpa" and output_attentions:
349
+ logger.warning_once(
350
+                         "`torch.nn.functional.scaled_dot_product_attention` does not support "
351
+                         "`output_attentions=True`. Falling back to eager attention. This warning "
352
+                         'can be removed using the argument `attn_implementation="eager"` when loading the model.'
353
+ )
354
+ else:
355
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
356
+
357
+ attn_output, attn_weights = attention_interface(
358
+ self,
359
+ queries,
360
+ keys,
361
+ values,
362
+ attention_mask,
363
+ is_causal=self.is_causal,
364
+ scaling=self.scale,
365
+ dropout=0.0 if not self.training else self.dropout,
366
+ )
367
+
368
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
369
+ attn_output = self.out_proj(attn_output)
370
+
371
+ if not output_attentions:
372
+ attn_weights = None
373
+
374
+ return attn_output, attn_weights
375
+
376
+ class Vision_FlashAttention2(nn.Module):
377
+ def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]) -> None:
378
+ super().__init__()
379
+ dim = config.hidden_size
380
+ self.num_heads = config.num_attention_heads
381
+ self.k_proj = nn.Linear(dim, dim)
382
+ self.v_proj = nn.Linear(dim, dim)
383
+ self.q_proj = nn.Linear(dim, dim)
384
+ self.out_proj = nn.Linear(dim, dim)
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ cu_seqlens: torch.Tensor,
390
+ rotary_pos_emb: Optional[torch.Tensor] = None,
391
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
392
+ ) -> torch.Tensor:
393
+
394
+ seq_length = hidden_states.shape[0]
395
+ q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
396
+ k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
397
+ v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
398
+
399
+
400
+ if position_embeddings is None:
401
+ logger.warning_once(
402
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
403
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
404
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
405
+ "removed and `position_embeddings` will be mandatory."
406
+ )
407
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
408
+ cos = emb.cos()
409
+ sin = emb.sin()
410
+ else:
411
+ cos, sin = position_embeddings
412
+
413
+ q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
414
+ q = q.squeeze(0)
415
+ k = k.squeeze(0)
416
+
417
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
418
+ if is_aiter_available:
419
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
420
+ max_seqlen, max_seqlen, return_lse=True)[0].reshape(
421
+ seq_length, -1)
422
+ else:
423
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
424
+ max_seqlen, max_seqlen).reshape(
425
+ seq_length, -1)
426
+ attn_output = self.out_proj(attn_output)
427
+ return attn_output, None
428
+
429
+
430
+
431
+ def rotate_half(x):
432
+ """Rotates half the hidden dims of the input."""
433
+ x1 = x[..., : x.shape[-1] // 2]
434
+ x2 = x[..., x.shape[-1] // 2 :]
435
+ return torch.cat((-x2, x1), dim=-1)
436
+
437
+
438
+ def apply_rotary_pos_emb_vision(
439
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
440
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
441
+ orig_q_dtype = q.dtype
442
+ orig_k_dtype = k.dtype
443
+ q, k = q.float(), k.float()
444
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
445
+ q_embed = (q * cos) + (rotate_half(q) * sin)
446
+ k_embed = (k * cos) + (rotate_half(k) * sin)
447
+ q_embed = q_embed.to(orig_q_dtype)
448
+ k_embed = k_embed.to(orig_k_dtype)
449
+ return q_embed, k_embed
450
+
451
+
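`rotate_half` and `apply_rotary_pos_emb_vision` can be exercised on their own; the cos/sin table below is a stand-in for what `VisionRope` plus the encoder's `rot_pos_emb` produce, and all sizes here are invented:

import torch

seq_len, num_heads, head_dim = 4, 2, 8

# Fake a rotary table of shape (seq_len, head_dim), mimicking how the encoder
# concatenates the half-dim frequency table with itself.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim // 2, 2).float() / (head_dim // 2)))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (4, 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (4, 4)
emb = torch.cat((emb, emb), dim=-1)                           # (4, 8)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
# apply_rotary_pos_emb_vision is the module-level function defined above.
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # shapes unchanged: (4, 2, 8) each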
452
+ class Vision_EagerAttention(nn.Module):
453
+ def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]) -> None:
454
+ super().__init__()
455
+ dim = config.hidden_size
456
+ self.num_heads = config.num_attention_heads
457
+ self.k_proj = nn.Linear(dim, dim)
458
+ self.v_proj = nn.Linear(dim, dim)
459
+ self.q_proj = nn.Linear(dim, dim)
460
+ self.out_proj = nn.Linear(dim, dim)
461
+ self.head_dim = dim // self.num_heads
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ cu_seqlens: torch.Tensor,
467
+ rotary_pos_emb: Optional[torch.Tensor] = None,
468
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
469
+ ) -> torch.Tensor:
470
+ seq_length = hidden_states.shape[0]
471
+ q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
472
+ k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
473
+ v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
474
+
475
+ if position_embeddings is None:
476
+ logger.warning_once(
477
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
478
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
479
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
480
+ "removed and `position_embeddings` will be mandatory."
481
+ )
482
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
483
+ cos = emb.cos()
484
+ sin = emb.sin()
485
+ else:
486
+ cos, sin = position_embeddings
487
+ q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
488
+
489
+ attention_mask = torch.full(
490
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
491
+ )
492
+ for i in range(1, len(cu_seqlens)):
493
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
494
+
495
+ q = q.transpose(0, 1)
496
+ k = k.transpose(0, 1)
497
+ v = v.transpose(0, 1)
498
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
499
+ attn_weights = attn_weights + attention_mask
500
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
501
+ attn_output = torch.matmul(attn_weights, v)
502
+ attn_output = attn_output.transpose(0, 1)
503
+ attn_output = attn_output.reshape(seq_length, -1)
504
+ attn_output = self.out_proj(attn_output)
505
+ return attn_output, None
506
+
507
+
508
+ VISION_ATTENTION_CLASSES = {
509
+ 'eager': Vision_EagerAttention,
510
+ 'flash_attention_2': Vision_FlashAttention2,
511
+ }
512
+
513
+
514
+ class Siglip2MLP(nn.Module):
515
+ def __init__(self, config):
516
+ super().__init__()
517
+ self.config = config
518
+ self.activation_fn = ACT2FN[config.hidden_act]
519
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
520
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
521
+
522
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
523
+ hidden_states = self.fc1(hidden_states)
524
+ hidden_states = self.activation_fn(hidden_states)
525
+ hidden_states = self.fc2(hidden_states)
526
+ return hidden_states
527
+
528
+
529
+ # class Siglip2EncoderLayer(GradientCheckpointingLayer):
530
+ class Siglip2EncoderLayer(nn.Module):
531
+ def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
532
+ super().__init__()
533
+ self.embed_dim = config.hidden_size
534
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
535
+ self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
536
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
537
+ self.mlp = Siglip2MLP(config)
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states: torch.Tensor,
542
+ attention_mask: torch.Tensor,
543
+ cu_seqlens: torch.Tensor,
544
+ rotary_pos_emb: Optional[torch.Tensor] = None,
545
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
546
+ output_attentions: Optional[bool] = False,
547
+ ) -> Tuple[torch.FloatTensor]:
548
+ """
549
+ Args:
550
+ hidden_states (`torch.FloatTensor`):
551
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
552
+ attention_mask (`torch.FloatTensor`):
553
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements
554
+ are indicated by very large negative values.
555
+ output_attentions (`bool`, *optional*, defaults to `False`):
556
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
557
+ returned tensors for more detail.
558
+ """
559
+ residual = hidden_states
560
+
561
+ hidden_states = self.layer_norm1(hidden_states)
562
+ hidden_states, attn_weights = self.self_attn(
563
+ hidden_states=hidden_states,
564
+ cu_seqlens=cu_seqlens,
565
+ rotary_pos_emb=rotary_pos_emb,
566
+ position_embeddings=position_embeddings,
567
+ )
568
+ hidden_states = residual + hidden_states
569
+
570
+ residual = hidden_states
571
+ hidden_states = self.layer_norm2(hidden_states)
572
+ hidden_states = self.mlp(hidden_states)
573
+ hidden_states = residual + hidden_states
574
+
575
+ outputs = (hidden_states,)
576
+
577
+ if output_attentions:
578
+ outputs += (attn_weights,)
579
+
580
+ return outputs
581
+
582
+ class VisionRope(nn.Module):
583
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
584
+ super().__init__()
585
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
586
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
587
+
588
+ def forward(self, seqlen: int) -> torch.Tensor:
589
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
590
+ freqs = torch.outer(seq, self.inv_freq)
591
+ return freqs
592
+
593
+ class Siglip2Encoder(nn.Module):
594
+ """
595
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
596
+ [`Siglip2EncoderLayer`].
597
+
598
+ Args:
599
+ config: Siglip2Config
600
+ """
601
+
602
+ def __init__(self, config: Siglip2Config):
603
+ super().__init__()
604
+ self.config = config
605
+ self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
606
+ self.gradient_checkpointing = False
607
+ self.spatial_merge_size = 2
608
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
609
+ self.patch_size = config.patch_size
610
+ self.window_size = self.patch_size * 2 * 8
611
+
612
+         assert config.hidden_size % (config.num_attention_heads * 2) == 0
613
+         self.rotary_pos_emb = VisionRope(config.hidden_size // config.num_attention_heads // 2)
614
+
615
+ def rot_pos_emb(self, spatial_shapes):
616
+ pos_ids = []
617
+
618
+ for h, w in spatial_shapes:
619
+ t = 1
620
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
621
+
622
+ hpos_ids = hpos_ids.reshape(
623
+ h // self.spatial_merge_size,
624
+ self.spatial_merge_size,
625
+ w // self.spatial_merge_size,
626
+ self.spatial_merge_size,
627
+ )
628
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
629
+ hpos_ids = hpos_ids.flatten()
630
+
631
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
632
+ wpos_ids = wpos_ids.reshape(
633
+ h // self.spatial_merge_size,
634
+ self.spatial_merge_size,
635
+ w // self.spatial_merge_size,
636
+ self.spatial_merge_size,
637
+ )
638
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
639
+ wpos_ids = wpos_ids.flatten()
640
+
641
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
642
+ pos_ids = torch.cat(pos_ids, dim=0)
643
+ max_grid_size = spatial_shapes.max()
644
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
645
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
646
+ return rotary_pos_emb
647
+
648
+ def get_window_index(self, spatial_shapes):
649
+ window_index: list = []
650
+ cu_window_seqlens: list = [0]
651
+ window_index_id = 0
652
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
653
+
654
+ for grid_h, grid_w in spatial_shapes:
655
+ grid_t = 1
656
+ llm_grid_h, llm_grid_w = (
657
+ grid_h // self.spatial_merge_size,
658
+ grid_w // self.spatial_merge_size,
659
+ )
660
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
661
+ pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size
662
+ pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size
663
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
664
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
665
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
666
+ index_padded = index_padded.reshape(
667
+ grid_t,
668
+ num_windows_h,
669
+ vit_merger_window_size,
670
+ num_windows_w,
671
+ vit_merger_window_size,
672
+ )
673
+
674
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
675
+ grid_t,
676
+ num_windows_h * num_windows_w,
677
+ vit_merger_window_size,
678
+ vit_merger_window_size,
679
+ )
680
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
681
+ index_padded = index_padded.reshape(-1)
682
+ index_new = index_padded[index_padded != -100]
683
+ window_index.append(index_new + window_index_id)
684
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
685
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
686
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
687
+
688
+ window_index = torch.cat(window_index, dim=0)
689
+
690
+ return window_index, cu_window_seqlens
691
+
692
+ @can_return_tuple
693
+ def forward(
694
+ self,
695
+ inputs_embeds,
696
+ spatial_shapes: torch.LongTensor,
697
+ attention_mask: Optional[torch.Tensor] = None,
698
+ output_attentions: Optional[bool] = None,
699
+ output_hidden_states: Optional[bool] = None,
700
+ ) -> BaseModelOutput:
701
+ r"""
702
+ Args:
703
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
704
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
705
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
706
+ than the model's internal embedding lookup matrix.
707
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
708
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
709
+
710
+ - 1 for tokens that are **not masked**,
711
+ - 0 for tokens that are **masked**.
712
+
713
+ [What are attention masks?](../glossary#attention-mask)
714
+ output_attentions (`bool`, *optional*):
715
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
716
+ returned tensors for more detail.
717
+ output_hidden_states (`bool`, *optional*):
718
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
719
+ for more detail.
720
+ return_dict (`bool`, *optional*):
721
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
722
+ """
723
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
724
+ output_hidden_states = (
725
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
726
+ )
727
+
728
+ encoder_states = () if output_hidden_states else None
729
+ all_attentions = () if output_attentions else None
730
+
731
+ hidden_states = inputs_embeds
732
+ rotary_pos_emb = self.rot_pos_emb(spatial_shapes)
733
+ window_index, cu_window_seqlens = self.get_window_index(spatial_shapes)
734
+ cu_window_seqlens = torch.tensor(
735
+ cu_window_seqlens,
736
+ device=hidden_states.device,
737
+ dtype=spatial_shapes.dtype if torch.jit.is_tracing() else torch.int32,
738
+ )
739
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
740
+
741
+ seq_len, _ = hidden_states.size()
742
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
743
+ hidden_states = hidden_states[window_index, :, :]
744
+ hidden_states = hidden_states.reshape(seq_len, -1)
745
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
746
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
747
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
748
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
749
+ position_embeddings = (emb.cos(), emb.sin())
750
+
751
+ cu_seqlens = torch.repeat_interleave(spatial_shapes[:, 0] * spatial_shapes[:, 1], 1).cumsum(
752
+ dim=0,
753
+ dtype=spatial_shapes.dtype if torch.jit.is_tracing() else torch.int32,
754
+ )
755
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
756
+
757
+ for layer_num, encoder_layer in enumerate(self.layers):
758
+ if output_hidden_states:
759
+ encoder_states = encoder_states + (hidden_states,)
760
+
761
+             if (layer_num + 1) % 8 == 0 or layer_num == len(self.layers) - 1:
762
+ cu_seqlens_now = cu_seqlens
763
+ else:
764
+ cu_seqlens_now = cu_window_seqlens
765
+
766
+ layer_outputs = encoder_layer(
767
+ hidden_states,
768
+ attention_mask,
769
+ cu_seqlens=cu_seqlens_now,
770
+ position_embeddings=position_embeddings
771
+ )
772
+
773
+ hidden_states = layer_outputs[0]
774
+ if output_attentions:
775
+ all_attentions = all_attentions + (layer_outputs[1],)
776
+
777
+
778
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
779
+ reverse_indices = torch.argsort(window_index)
780
+ hidden_states = hidden_states[reverse_indices, :, :]
781
+ hidden_states = hidden_states.reshape(seq_len, -1)
782
+
783
+ if output_hidden_states:
784
+ encoder_states = encoder_states + (hidden_states,)
785
+
786
+ return BaseModelOutput(
787
+ last_hidden_state=hidden_states,
788
+ hidden_states=encoder_states,
789
+ attentions=all_attentions,
790
+ )
791
+
792
+
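The `get_window_index` logic above regroups merged tokens into fixed-size windows, padding the grid when it does not divide evenly (Qwen2-VL-style windowed attention). Below is a simplified standalone sketch of that regrouping for a single invented 3x5 merged grid with 2x2 windows; it mirrors the idea, not the exact method above:

import torch
import torch.nn.functional as F

llm_grid_h, llm_grid_w, win = 3, 5, 2   # merged-token grid and window size

index = torch.arange(llm_grid_h * llm_grid_w).reshape(llm_grid_h, llm_grid_w)
pad_h = (win - llm_grid_h % win) % win
pad_w = (win - llm_grid_w % win) % win
index = F.pad(index, (0, pad_w, 0, pad_h), value=-100)        # mark padding slots

nh, nw = (llm_grid_h + pad_h) // win, (llm_grid_w + pad_w) // win
windows = index.reshape(nh, win, nw, win).permute(0, 2, 1, 3)  # (nh, nw, win, win)
flat = windows.reshape(-1)
window_order = flat[flat != -100]                              # token order after regrouping
seqlens = (windows != -100).sum(dim=(2, 3)).reshape(-1)        # real tokens per window
print(window_order)  # tensor([ 0,  1,  5,  6,  2,  3,  7,  8,  4,  9, 10, 11, 12, 13, 14])
print(seqlens)       # tensor([4, 4, 2, 2, 2, 1]); edge windows hold fewer real tokens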
793
+ SIGLIP2_VISION_INPUTS_DOCSTRING = r"""
794
+ Args:
795
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_patches, num_channels * patch_size * patch_size)`):
796
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
797
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
798
+ output_attentions (`bool`, *optional*):
799
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
800
+ tensors for more detail.
801
+ output_hidden_states (`bool`, *optional*):
802
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
803
+ more detail.
804
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
805
+ Whether to interpolate the pre-trained position encodings.
806
+ return_dict (`bool`, *optional*):
807
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
808
+ """
809
+
810
+
811
+ class Siglip2VisionTransformer(nn.Module):
812
+ def __init__(self, config: Siglip2VisionConfig):
813
+ super().__init__()
814
+ self.config = config
815
+ embed_dim = config.hidden_size
816
+
817
+ self.embeddings = Siglip2VisionEmbeddingsWoPos(config)
818
+ self.encoder = Siglip2Encoder(config)
819
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
820
+ self.use_head = False
821
+ if self.use_head:
822
+ self.head = Siglip2MultiheadAttentionPoolingHead(config)
823
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
824
+
825
+ @can_return_tuple
826
+ @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
827
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
828
+ def forward(
829
+ self,
830
+ pixel_values: torch.FloatTensor,
831
+ attention_mask: torch.Tensor,
832
+ spatial_shapes: torch.LongTensor,
833
+ output_attentions: Optional[bool] = None,
834
+ output_hidden_states: Optional[bool] = None,
835
+ ) -> BaseModelOutputWithPooling:
836
+ r"""
837
+ Returns:
838
+
839
+ """
840
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
841
+ output_hidden_states = (
842
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
843
+ )
844
+
845
+ bs, length, dim = pixel_values.shape
846
+ hidden_states = self.embeddings(pixel_values)
847
+
848
+ if attention_mask is not None and not self._use_flash_attention_2:
849
+ encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
850
+ else:
851
+ encoder_attention_mask = attention_mask
852
+
853
+ encoder_outputs: BaseModelOutput = self.encoder(
854
+ inputs_embeds=hidden_states,
855
+ spatial_shapes=spatial_shapes,
856
+ attention_mask=encoder_attention_mask,
857
+ output_attentions=output_attentions,
858
+ output_hidden_states=output_hidden_states,
859
+ )
860
+
861
+ last_hidden_state = encoder_outputs.last_hidden_state
862
+
863
+ last_hidden_state = self.post_layernorm(last_hidden_state)
864
+
865
+ return BaseModelOutputWithPooling(
866
+ last_hidden_state=last_hidden_state,
867
+ pooler_output=None,
868
+ hidden_states=encoder_outputs.hidden_states,
869
+ attentions=encoder_outputs.attentions,
870
+ )
871
+
872
+
873
+ class Siglip2TextEmbeddings(nn.Module):
874
+ def __init__(self, config: Siglip2TextConfig):
875
+ super().__init__()
876
+ embed_dim = config.hidden_size
877
+
878
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
879
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
880
+
881
+ self.register_buffer(
882
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
883
+ )
884
+
885
+ def forward(
886
+ self,
887
+ input_ids: Optional[torch.LongTensor] = None,
888
+ position_ids: Optional[torch.LongTensor] = None,
889
+ inputs_embeds: Optional[torch.FloatTensor] = None,
890
+ ) -> torch.Tensor:
891
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
892
+ max_position_embedding = self.position_embedding.weight.shape[0]
893
+ if seq_length > max_position_embedding:
894
+ raise ValueError(
895
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
896
+                 f"{seq_length} and max_position_embeddings: {max_position_embedding})"
897
+ )
898
+
899
+ if position_ids is None:
900
+ position_ids = self.position_ids[:, :seq_length]
901
+
902
+ if inputs_embeds is None:
903
+ inputs_embeds = self.token_embedding(input_ids)
904
+
905
+ position_embeddings = self.position_embedding(position_ids)
906
+ embeddings = inputs_embeds + position_embeddings
907
+
908
+ return embeddings
909
+
910
+
911
+ def _trunc_normal_(tensor, mean, std, a, b):
912
+ def norm_cdf(x):
913
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
914
+
915
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
916
+ warnings.warn(
917
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
918
+ "The distribution of values may be incorrect.",
919
+ stacklevel=2,
920
+ )
921
+
922
+ l = norm_cdf((a - mean) / std)
923
+ u = norm_cdf((b - mean) / std)
924
+
925
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
926
+
927
+ tensor.erfinv_()
928
+
929
+ tensor.mul_(std * math.sqrt(2.0))
930
+ tensor.add_(mean)
931
+
932
+ tensor.clamp_(min=a, max=b)
933
+
934
+
935
+ def trunc_normal_tf_(
936
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
937
+ ) -> torch.Tensor:
938
+ """
939
+ Args:
940
+ tensor: an n-dimensional `torch.Tensor`
941
+ mean: the mean of the normal distribution
942
+ std: the standard deviation of the normal distribution
943
+ a: the minimum cutoff value
944
+ b: the maximum cutoff value
945
+ """
946
+ with torch.no_grad():
947
+ _trunc_normal_(tensor, 0, 1.0, a, b)
948
+ tensor.mul_(std).add_(mean)
949
+
950
+
951
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
952
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
953
+ if mode == "fan_in":
954
+ denom = fan_in
955
+ elif mode == "fan_out":
956
+ denom = fan_out
957
+ elif mode == "fan_avg":
958
+ denom = (fan_in + fan_out) / 2
959
+
960
+ variance = scale / denom
961
+
962
+ if distribution == "truncated_normal":
963
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
964
+ elif distribution == "normal":
965
+ with torch.no_grad():
966
+ tensor.normal_(std=math.sqrt(variance))
967
+ elif distribution == "uniform":
968
+ bound = math.sqrt(3 * variance)
969
+ with torch.no_grad():
970
+ tensor.uniform_(-bound, bound)
971
+ else:
972
+ raise ValueError(f"invalid distribution {distribution}")
973
+
974
+
975
+ def lecun_normal_(tensor):
976
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
977
+
978
+
979
+ def default_flax_embed_init(tensor):
980
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
981
+
982
+
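A small check of the initializers above: `lecun_normal_` draws from a truncated normal whose variance is scaled to roughly 1/fan_in. The layer size and tolerance below are only indicative:

import torch
from torch import nn

linear = nn.Linear(512, 256)        # fan_in = 512
lecun_normal_(linear.weight)        # module-level helper defined above
print(linear.weight.var().item())   # close to 1/512 ~= 0.00195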
983
+ SIGLIP2_TEXT_INPUTS_DOCSTRING = r"""
984
+ Args:
985
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
986
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
987
+ it.
988
+
989
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
990
+ [`PreTrainedTokenizer.__call__`] for details.
991
+
992
+ [What are input IDs?](../glossary#input-ids)
993
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
994
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
995
+
996
+ - 1 for tokens that are **not masked**,
997
+ - 0 for tokens that are **masked**.
998
+
999
+ [What are attention masks?](../glossary#attention-mask)
1000
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1001
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1002
+ config.max_position_embeddings - 1]`.
1003
+
1004
+ [What are position IDs?](../glossary#position-ids)
1005
+ output_attentions (`bool`, *optional*):
1006
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1007
+ tensors for more detail.
1008
+ output_hidden_states (`bool`, *optional*):
1009
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1010
+ more detail.
1011
+ return_dict (`bool`, *optional*):
1012
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1013
+ """
1014
+
1015
+
1016
+ class Siglip2TextTransformer(nn.Module):
1017
+ def __init__(self, config: Siglip2TextConfig):
1018
+ super().__init__()
1019
+ self.config = config
1020
+ embed_dim = config.hidden_size
1021
+ self.embeddings = Siglip2TextEmbeddings(config)
1022
+ self.encoder = Siglip2Encoder(config)
1023
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1024
+
1025
+ self.head = nn.Linear(embed_dim, config.projection_size)
1026
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1027
+
1028
+ @can_return_tuple
1029
+ @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
1030
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig)
1031
+ def forward(
1032
+ self,
1033
+ input_ids: Optional[torch.Tensor] = None,
1034
+ attention_mask: Optional[torch.Tensor] = None,
1035
+ position_ids: Optional[torch.Tensor] = None,
1036
+ output_attentions: Optional[bool] = None,
1037
+ output_hidden_states: Optional[bool] = None,
1038
+ ) -> BaseModelOutputWithPooling:
1039
+ r"""
1040
+ Returns:
1041
+
1042
+ """
1043
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1044
+ output_hidden_states = (
1045
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1046
+ )
1047
+
1048
+ if input_ids is None:
1049
+ raise ValueError("You have to specify input_ids")
1050
+
1051
+ input_shape = input_ids.size()
1052
+ input_ids = input_ids.view(-1, input_shape[-1])
1053
+
1054
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
1055
+
1056
+ if attention_mask is not None and not self._use_flash_attention_2:
1057
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1058
+
1059
+ encoder_outputs: BaseModelOutput = self.encoder(
1060
+ inputs_embeds=hidden_states,
1061
+ attention_mask=attention_mask,
1062
+ output_attentions=output_attentions,
1063
+ output_hidden_states=output_hidden_states,
1064
+ )
1065
+
1066
+ last_hidden_state = encoder_outputs.last_hidden_state
1067
+
1068
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
1069
+
1070
+ pooled_output = last_hidden_state[:, -1, :]
1071
+ pooled_output = self.head(pooled_output)
1072
+
1073
+ return BaseModelOutputWithPooling(
1074
+ last_hidden_state=last_hidden_state,
1075
+ pooler_output=pooled_output,
1076
+ hidden_states=encoder_outputs.hidden_states,
1077
+ attentions=encoder_outputs.attentions,
1078
+ )
1079
+
1080
+
1081
+ SIGLIP2_START_DOCSTRING = r"""
1082
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1083
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1084
+ etc.)
1085
+
1086
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1087
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1088
+ and behavior.
1089
+
1090
+ Parameters:
1091
+ config ([`Siglip2Config`]): Model configuration class with all the parameters of the model.
1092
+ Initializing with a config file does not load the weights associated with the model, only the
1093
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1094
+ """
1095
+
1096
+ SIGLIP2_INPUTS_DOCSTRING = r"""
1097
+ Args:
1098
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1099
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1100
+ it.
1101
+
1102
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1103
+ [`PreTrainedTokenizer.__call__`] for details.
1104
+
1105
+ [What are input IDs?](../glossary#input-ids)
1106
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1107
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1108
+
1109
+ - 1 for tokens that are **not masked**,
1110
+ - 0 for tokens that are **masked**.
1111
+
1112
+ [What are attention masks?](../glossary#attention-mask)
1113
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1114
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1115
+ config.max_position_embeddings - 1]`.
1116
+
1117
+ [What are position IDs?](../glossary#position-ids)
1118
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_patches, num_channels * patch_size * patch_size)`):
1119
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1120
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1121
+ return_loss (`bool`, *optional*):
1122
+ Whether or not to return the contrastive loss.
1123
+ output_attentions (`bool`, *optional*):
1124
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1125
+ tensors for more detail.
1126
+ output_hidden_states (`bool`, *optional*):
1127
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1128
+ more detail.
1129
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
1130
+ Whether to interpolate the pre-trained position encodings.
1131
+ return_dict (`bool`, *optional*):
1132
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1133
+ """
1134
+
1135
+
1136
+ class Siglip2PreTrainedModel(PreTrainedModel):
1137
+
1138
+ config_class = Siglip2Config
1139
+ base_model_prefix = "siglip2"
1140
+ supports_gradient_checkpointing = True
1141
+
1142
+ _no_split_modules = [
1143
+ "Siglip2TextEmbeddings",
1144
+         "Siglip2EncoderLayer",
1145
+         "Siglip2VisionEmbeddings",
1147
+ "Siglip2MultiheadAttentionPoolingHead",
1148
+ ]
1149
+ _supports_flash_attn_2 = True
1150
+ _supports_sdpa = True
1151
+
1152
+ def _init_weights(self, module):
1153
+ """Initialize the weights"""
1154
+ if isinstance(module, Siglip2VisionEmbeddings):
1155
+ width = (
1156
+ self.config.vision_config.hidden_size
1157
+ if isinstance(self.config, Siglip2Config)
1158
+ else self.config.hidden_size
1159
+ )
1160
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
1161
+ elif isinstance(module, nn.Embedding):
1162
+ default_flax_embed_init(module.weight)
1163
+ elif isinstance(module, Siglip2Attention):
1164
+ nn.init.xavier_uniform_(module.q_proj.weight)
1165
+ nn.init.xavier_uniform_(module.k_proj.weight)
1166
+ nn.init.xavier_uniform_(module.v_proj.weight)
1167
+ nn.init.xavier_uniform_(module.out_proj.weight)
1168
+ nn.init.zeros_(module.q_proj.bias)
1169
+ nn.init.zeros_(module.k_proj.bias)
1170
+ nn.init.zeros_(module.v_proj.bias)
1171
+ nn.init.zeros_(module.out_proj.bias)
1172
+ elif isinstance(module, Siglip2MLP):
1173
+ nn.init.xavier_uniform_(module.fc1.weight)
1174
+ nn.init.xavier_uniform_(module.fc2.weight)
1175
+ nn.init.normal_(module.fc1.bias, std=1e-6)
1176
+ nn.init.normal_(module.fc2.bias, std=1e-6)
1177
+ elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
1178
+ nn.init.xavier_uniform_(module.probe.data)
1179
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
1180
+ nn.init.zeros_(module.attention.in_proj_bias.data)
1181
+ elif isinstance(module, Siglip2Model):
1182
+ logit_scale_init = torch.log(torch.tensor(1.0))
1183
+ module.logit_scale.data.fill_(logit_scale_init)
1184
+ module.logit_bias.data.zero_()
1185
+ elif isinstance(module, Siglip2ForImageClassification):
1186
+ nn.init.normal_(
1187
+ module.classifier.weight,
1188
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
1189
+ )
1190
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
1191
+ lecun_normal_(module.weight)
1192
+ if module.bias is not None:
1193
+ nn.init.zeros_(module.bias)
1194
+ elif isinstance(module, nn.LayerNorm):
1195
+ module.bias.data.zero_()
1196
+ module.weight.data.fill_(1.0)
1197
+
1198
+
1199
+ @add_start_docstrings(
1200
+ """The text model from Siglip2 without any head or projection on top.""",
1201
+ SIGLIP2_START_DOCSTRING,
1202
+ )
1203
+ class Siglip2TextModel(Siglip2PreTrainedModel):
1204
+ config_class = Siglip2TextConfig
1205
+
1206
+ def __init__(self, config: Siglip2TextConfig):
1207
+ super().__init__(config)
1208
+ self.text_model = Siglip2TextTransformer(config)
1209
+ # Initialize weights and apply final processing
1210
+ self.post_init()
1211
+
1212
+ def get_input_embeddings(self) -> nn.Module:
1213
+ return self.text_model.embeddings.token_embedding
1214
+
1215
+ def set_input_embeddings(self, value):
1216
+ self.text_model.embeddings.token_embedding = value
1217
+
1218
+ @can_return_tuple
1219
+ @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
1220
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig)
1221
+ def forward(
1222
+ self,
1223
+ input_ids: Optional[torch.Tensor] = None,
1224
+ attention_mask: Optional[torch.Tensor] = None,
1225
+ position_ids: Optional[torch.Tensor] = None,
1226
+ output_attentions: Optional[bool] = None,
1227
+ output_hidden_states: Optional[bool] = None,
1228
+ ) -> BaseModelOutputWithPooling:
1229
+ r"""
1230
+ Returns:
1231
+
1232
+ """
1233
+
1234
+ return self.text_model(
1235
+ input_ids=input_ids,
1236
+ attention_mask=attention_mask,
1237
+ position_ids=position_ids,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ )
1241
+
1242
+
1243
+ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
1244
+ """Multihead Attention Pooling."""
1245
+
1246
+ def __init__(self, config: Siglip2VisionConfig):
1247
+ super().__init__()
1248
+
1249
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1250
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1251
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1252
+ self.mlp = Siglip2MLP(config)
1253
+ self.num_heads = config.num_attention_heads
1254
+
1255
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1256
+ batch_size = hidden_state.shape[0]
1257
+ probe = self.probe.repeat(batch_size, 1, 1)
1258
+
1259
+ if attention_mask is not None:
1260
+ target_len, source_len = probe.shape[1], hidden_state.shape[1]
1261
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
1262
+ attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
1263
+ attention_mask = attention_mask.reshape(-1, target_len, source_len)
1264
+
1265
+ hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
1266
+
1267
+ residual = hidden_state
1268
+ hidden_state = self.layernorm(hidden_state)
1269
+ hidden_state = residual + self.mlp(hidden_state)
1270
+
1271
+ return hidden_state[:, 0]
1272
+
1273
+
1274
+ @add_start_docstrings(
1275
+ """The vision model from Siglip2 without any head or projection on top.""",
1276
+ SIGLIP2_START_DOCSTRING,
1277
+ )
1278
+ class Siglip2VisionModel(Siglip2PreTrainedModel):
1279
+ config_class = Siglip2VisionConfig
1280
+ main_input_name = "pixel_values"
1281
+
1282
+ def __init__(self, config: Siglip2VisionConfig):
1283
+ super().__init__(config)
1284
+
1285
+ self.vision_model = Siglip2VisionTransformer(config)
1286
+
1287
+ # Initialize weights and apply final processing
1288
+ self.post_init()
1289
+
1290
+ def get_input_embeddings(self) -> nn.Module:
1291
+ return self.vision_model.embeddings.patch_embedding
1292
+
1293
+ @can_return_tuple
1294
+ @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
1295
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
1296
+ def forward(
1297
+ self,
1298
+ pixel_values: torch.FloatTensor,
1299
+ pixel_attention_mask: torch.Tensor,
1300
+ spatial_shapes: torch.LongTensor,
1301
+ output_attentions: Optional[bool] = None,
1302
+ output_hidden_states: Optional[bool] = None,
1303
+ ) -> BaseModelOutputWithPooling:
1304
+ r"""
1305
+ Returns:
1306
+
1307
+         """
1308
+ return self.vision_model(
1309
+ pixel_values=pixel_values,
1310
+ attention_mask=pixel_attention_mask,
1311
+ spatial_shapes=spatial_shapes,
1312
+ output_attentions=output_attentions,
1313
+ output_hidden_states=output_hidden_states,
1314
+ )
1315
+
1316
+
1317
+ @add_start_docstrings(SIGLIP2_START_DOCSTRING)
1318
+ class Siglip2Model(Siglip2PreTrainedModel):
1319
+ config_class = Siglip2Config
1320
+
1321
+ def __init__(self, config: Siglip2Config):
1322
+ super().__init__(config)
1323
+
1324
+ if not isinstance(config.text_config, Siglip2TextConfig):
1325
+ raise TypeError(
1326
+ "config.text_config is expected to be of type Siglip2TextConfig but is of type"
1327
+ f" {type(config.text_config)}."
1328
+ )
1329
+
1330
+ if not isinstance(config.vision_config, Siglip2VisionConfig):
1331
+ raise TypeError(
1332
+ "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
1333
+ f" {type(config.vision_config)}."
1334
+ )
1335
+
1336
+ text_config = config.text_config
1337
+ vision_config = config.vision_config
1338
+
1339
+ text_model = Siglip2TextModel._from_config(text_config)
1340
+ vision_model = Siglip2VisionModel._from_config(vision_config)
1341
+
1342
+ self.text_model = text_model.text_model
1343
+ self.vision_model = vision_model.vision_model
1344
+
1345
+ self.logit_scale = nn.Parameter(torch.randn(1))
1346
+ self.logit_bias = nn.Parameter(torch.randn(1))
1347
+
1348
+ self.post_init()
1349
+
1350
+ @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
1351
+ def get_text_features(
1352
+ self,
1353
+ input_ids: Optional[torch.Tensor] = None,
1354
+ attention_mask: Optional[torch.Tensor] = None,
1355
+ position_ids: Optional[torch.Tensor] = None,
1356
+ output_attentions: Optional[bool] = None,
1357
+ output_hidden_states: Optional[bool] = None,
1358
+ ) -> torch.FloatTensor:
1359
+ r"""
1360
+             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1361
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1362
+ applying the projection layer to the pooled output of [`Siglip2TextModel`].
1363
+
1364
+ """
1365
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1366
+ output_hidden_states = (
1367
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1368
+ )
1369
+
1370
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
1371
+ input_ids=input_ids,
1372
+ attention_mask=attention_mask,
1373
+ position_ids=position_ids,
1374
+ output_attentions=output_attentions,
1375
+ output_hidden_states=output_hidden_states,
1376
+ )
1377
+
1378
+ pooled_output = text_outputs.pooler_output
1379
+
1380
+ return pooled_output
1381
+
1382
+ @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
1383
+ def get_image_features(
1384
+ self,
1385
+ pixel_values: Optional[torch.FloatTensor] = None,
1386
+ pixel_attention_mask: Optional[torch.Tensor] = None,
1387
+ spatial_shapes: Optional[torch.LongTensor] = None,
1388
+ output_attentions: Optional[bool] = None,
1389
+ output_hidden_states: Optional[bool] = None,
1390
+ ) -> torch.FloatTensor:
1391
+ r"""
1392
+ Returns:
1393
+             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1394
+ applying the projection layer to the pooled output of [`Siglip2VisionModel`].
1395
+
1396
+ """
1397
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1398
+ output_hidden_states = (
1399
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1400
+ )
1401
+
1402
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
1403
+ pixel_values=pixel_values,
1404
+ attention_mask=pixel_attention_mask,
1405
+ spatial_shapes=spatial_shapes,
1406
+ output_attentions=output_attentions,
1407
+ output_hidden_states=output_hidden_states,
1408
+ )
1409
+
1410
+ pooled_output = vision_outputs.pooler_output
1411
+
1412
+ return pooled_output
1413
+
1414
+ @can_return_tuple
1415
+ @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING)
1416
+ @replace_return_docstrings(output_type=Siglip2Output, config_class=Siglip2Config)
1417
+ def forward(
1418
+ self,
1419
+ input_ids: Optional[torch.LongTensor] = None,
1420
+ pixel_values: Optional[torch.FloatTensor] = None,
1421
+ pixel_attention_mask: Optional[torch.Tensor] = None,
1422
+ spatial_shapes: Optional[torch.LongTensor] = None,
1423
+ attention_mask: Optional[torch.Tensor] = None,
1424
+ position_ids: Optional[torch.LongTensor] = None,
1425
+ return_loss: Optional[bool] = None,
1426
+ output_attentions: Optional[bool] = None,
1427
+ output_hidden_states: Optional[bool] = None,
1428
+ ) -> Siglip2Output:
1429
+ r"""
1430
+ Returns:
1431
+
1432
+ """
1433
+
1434
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1435
+ output_hidden_states = (
1436
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1437
+ )
1438
+
1439
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
1440
+ pixel_values=pixel_values,
1441
+ attention_mask=pixel_attention_mask,
1442
+ spatial_shapes=spatial_shapes,
1443
+ output_attentions=output_attentions,
1444
+ output_hidden_states=output_hidden_states,
1445
+ )
1446
+
1447
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
1448
+ input_ids=input_ids,
1449
+ attention_mask=attention_mask,
1450
+ position_ids=position_ids,
1451
+ output_attentions=output_attentions,
1452
+ output_hidden_states=output_hidden_states,
1453
+ )
1454
+
1455
+ image_embeds = vision_outputs.pooler_output
1456
+ text_embeds = text_outputs.pooler_output
1457
+
1458
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1459
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1460
+
1461
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
1462
+
1463
+ logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
1464
+ logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
1465
+
1466
+ logits_per_image = logits_per_text.t()
1467
+
1468
+ loss = None
1469
+ if return_loss:
1470
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
1471
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
1472
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
1473
+ nll = -torch.sum(loglik, dim=-1)
1474
+ loss = nll.mean()
1475
+
1476
+ return Siglip2Output(
1477
+ loss=loss,
1478
+ logits_per_image=logits_per_image,
1479
+ logits_per_text=logits_per_text,
1480
+ text_embeds=text_embeds,
1481
+ image_embeds=image_embeds,
1482
+ text_model_output=text_outputs,
1483
+ vision_model_output=vision_outputs,
1484
+ )
1485
+
1486
+
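The `return_loss` branch in `Siglip2Model.forward` above is the pairwise sigmoid (SigLIP-style) objective: matched image-text pairs on the diagonal act as positives (+1) and every other pair as a negative (-1). The same computation on an invented 3x3 logit matrix:

import torch
import torch.nn.functional as F

logits_per_text = torch.tensor([[ 4.0, -2.0, -3.0],
                                [-1.0,  5.0, -2.0],
                                [-3.0, -1.0,  6.0]])

eye = torch.eye(3)
signs = -torch.ones_like(logits_per_text) + 2 * eye   # +1 on the diagonal, -1 elsewhere
loss = -F.logsigmoid(signs * logits_per_text).sum(dim=-1).mean()
print(loss)  # ~0.33: positives have large logits and negatives sit well below zero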
1487
+ @add_start_docstrings(
1488
+ """
1489
+ Siglip2 vision encoder with an image classification head on top (a
1490
+ linear layer on top of the pooled final hidden states of
1491
+ the patch tokens) e.g. for ImageNet.
1492
+ """,
1493
+ SIGLIP2_START_DOCSTRING,
1494
+ )
1495
+ class Siglip2ForImageClassification(Siglip2PreTrainedModel):
1496
+ main_input_name = "pixel_values"
1497
+
1498
+ def __init__(self, config: Siglip2Config) -> None:
1499
+ super().__init__(config)
1500
+
1501
+ self.num_labels = config.num_labels
1502
+
1503
+ vision_model = Siglip2VisionModel._from_config(config.vision_config)
1504
+ self.vision_model = vision_model.vision_model
1505
+
1506
+ self.classifier = (
1507
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1508
+ )
1509
+
1510
+ self.post_init()
1511
+
1512
+ @can_return_tuple
1513
+ @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING)
1514
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
1515
+ def forward(
1516
+ self,
1517
+ pixel_values: Optional[torch.Tensor] = None,
1518
+ pixel_attention_mask: Optional[torch.Tensor] = None,
1519
+ spatial_shapes: Optional[torch.LongTensor] = None,
1520
+ labels: Optional[torch.Tensor] = None,
1521
+ output_attentions: Optional[bool] = None,
1522
+ output_hidden_states: Optional[bool] = None,
1523
+ ) -> ImageClassifierOutput:
1524
+ r"""
1525
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1526
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1527
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1528
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1529
+
1530
+ Returns:
1531
+
1532
+ """
1533
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1534
+ output_hidden_states = (
1535
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1536
+ )
1537
+
1538
+ outputs: BaseModelOutputWithPooling = self.vision_model(
1539
+ pixel_values,
1540
+ attention_mask=pixel_attention_mask,
1541
+ spatial_shapes=spatial_shapes,
1542
+ output_attentions=output_attentions,
1543
+ output_hidden_states=output_hidden_states,
1544
+ )
1545
+
1546
+ sequence_output = outputs.last_hidden_state
1547
+
1548
+ if pixel_attention_mask is not None:
1549
+ pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
1550
+ sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
1551
+ else:
1552
+ sequence_output = torch.mean(sequence_output, dim=1)
1553
+
1554
+ logits = self.classifier(sequence_output)
1555
+
1556
+ loss = None
1557
+ if labels is not None:
1558
+ labels = labels.to(logits.device)
1559
+ if self.config.problem_type is None:
1560
+ if self.num_labels == 1:
1561
+ self.config.problem_type = "regression"
1562
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1563
+ self.config.problem_type = "single_label_classification"
1564
+ else:
1565
+ self.config.problem_type = "multi_label_classification"
1566
+
1567
+ if self.config.problem_type == "regression":
1568
+ loss_fct = MSELoss()
1569
+ if self.num_labels == 1:
1570
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1571
+ else:
1572
+ loss = loss_fct(logits, labels)
1573
+ elif self.config.problem_type == "single_label_classification":
1574
+ loss_fct = CrossEntropyLoss()
1575
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1576
+ elif self.config.problem_type == "multi_label_classification":
1577
+ loss_fct = BCEWithLogitsLoss()
1578
+ loss = loss_fct(logits, labels)
1579
+
1580
+ return ImageClassifierOutput(
1581
+ loss=loss,
1582
+ logits=logits,
1583
+ hidden_states=outputs.hidden_states,
1584
+ attentions=outputs.attentions,
1585
+ )
1586
+
1587
+
1588
+ __all__ = [
1589
+ "Siglip2Model",
1590
+ "Siglip2PreTrainedModel",
1591
+ "Siglip2TextModel",
1592
+ "Siglip2VisionModel",
1593
+ "Siglip2ForImageClassification",
1594
+ ]
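For reference, the masked mean pooling used by `Siglip2ForImageClassification` above can be reproduced in isolation; the shapes and the mask below are assumptions for illustration:

import torch

hidden = torch.randn(2, 6, 16)                 # (batch, max_num_patches, hidden_size)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])      # 1 = real patch, 0 = padding

pool_mask = mask[..., None].to(hidden.dtype)   # (2, 6, 1)
pooled = (hidden * pool_mask).sum(dim=1) / pool_mask.sum(dim=1)
print(pooled.shape)                            # torch.Size([2, 16])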
modeling_youtu_vl.py ADDED
@@ -0,0 +1,1542 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 Tencent Youtu lab, DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ import os
23
+ from functools import partial
24
+ from typing import Callable, Optional, Tuple, Union, List, Any, Dict
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from torch import nn
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
36
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import (
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ can_return_tuple,
43
+ is_torch_flex_attn_available,
44
+ logging,
45
+ replace_return_docstrings,
46
+ is_flash_attn_2_available,
47
+ )
48
+ from transformers.utils.deprecation import deprecate_kwarg
49
+ from .configuration_youtu_vl import YoutuVLConfig
50
+
51
+ from .modeling_siglip2 import Siglip2VisionModel, Siglip2VisionEmbeddings
52
+ from .configuration_siglip2 import Siglip2VisionConfig
53
+
54
+ if is_torch_flex_attn_available():
55
+ from torch.nn.attention.flex_attention import BlockMask
56
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
57
+
58
+ is_aiter_available = False
59
+
60
+ if is_flash_attn_2_available():
61
+ try:
62
+ from aiter import flash_attn_varlen_func
63
+ is_aiter_available = True
64
+ except ImportError:
65
+ from flash_attn import flash_attn_varlen_func
66
+ else:
67
+ flash_attn_varlen_func = None
68
+
69
+ logger = logging.get_logger(__name__)
70
+ _CONFIG_FOR_DOC = "YoutuVLConfig"
71
+
72
+ class YoutuRMSNorm(nn.Module):
73
+ def __init__(self, hidden_size, eps=1e-6):
74
+ super().__init__()
75
+ self.weight = nn.Parameter(torch.ones(hidden_size))
76
+ self.variance_epsilon = eps
77
+
78
+ def forward(self, hidden_states):
79
+ input_dtype = hidden_states.dtype
80
+
81
+ hidden_states = hidden_states.to(torch.float32)
82
+
83
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
84
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
85
+ return self.weight * hidden_states.to(input_dtype)
86
+
87
+ def extra_repr(self):
88
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
89
+
90
+
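
YoutuRMSNorm normalizes each feature vector by its root mean square in float32 before casting back and rescaling by a learned weight. A quick standalone check of the same computation (toy sizes, no model weights involved):

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # Normalize by sqrt(mean(x^2) + eps) in float32, then rescale and restore dtype.
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    return (weight * (x32 * torch.rsqrt(variance + eps))).to(x.dtype)

x = torch.randn(2, 5, 8, dtype=torch.bfloat16)
out = rms_norm(x, torch.ones(8))
print(out.shape, out.dtype)  # torch.Size([2, 5, 8]) torch.bfloat16
```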
91
+ class YoutuRotaryEmbedding(nn.Module):
92
+ def __init__(self, config: YoutuVLConfig, device=None):
93
+ super().__init__()
94
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
95
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
96
+ else:
97
+ self.rope_type = "default"
98
+ self.max_seq_len_cached = config.max_position_embeddings
99
+ self.original_max_seq_len = config.max_position_embeddings
100
+
101
+ self.config = config
102
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
103
+
104
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
105
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
106
+ self.original_inv_freq = self.inv_freq
107
+
108
+ @torch.no_grad()
109
+ @dynamic_rope_update
110
+ def forward(self, x, position_ids):
111
+ """
112
+ Compute rotary positional embeddings.
113
+
114
+ Args:
115
+ x (torch.Tensor): Input tensor, shape (batch_size, seq_len, feature_dim)
116
+ position_ids (torch.LongTensor): Position indices, shape (batch_size, seq_len)
117
+
118
+ Returns:
119
+ Tuple of (cos, sin) tensors for rotary embedding
120
+ """
121
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
122
+ position_ids_expanded = position_ids[:, None, :].float()
123
+
124
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
125
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
126
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
127
+ emb = torch.cat((freqs, freqs), dim=-1)
128
+ cos = emb.cos() * self.attention_scaling
129
+ sin = emb.sin() * self.attention_scaling
130
+
131
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
132
+
133
+
134
+ class YoutuMLP(nn.Module):
135
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
136
+ super().__init__()
137
+ self.config = config
138
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
139
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
140
+
141
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
142
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
143
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
144
+ self.act_fn = ACT2FN[config.hidden_act]
145
+
146
+ def forward(self, x):
147
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
148
+ return down_proj
149
+
150
+
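
YoutuMLP is the standard gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`. A shape-only sketch with invented sizes, assuming a SiLU activation (the real one comes from `config.hidden_act`):

```python
import torch
from torch import nn

hidden, inter = 16, 64                 # hypothetical sizes
gate = nn.Linear(hidden, inter, bias=False)
up = nn.Linear(hidden, inter, bias=False)
down = nn.Linear(inter, hidden, bias=False)

x = torch.randn(2, 7, hidden)
y = down(nn.functional.silu(gate(x)) * up(x))
print(y.shape)  # torch.Size([2, 7, 16])
```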
151
+ def rotate_half(x):
152
+ """
153
+ Rotates half the hidden dims of the input.
154
+ """
155
+ x1 = x[..., : x.shape[-1] // 2]
156
+ x2 = x[..., x.shape[-1] // 2 :]
157
+ return torch.cat((-x2, x1), dim=-1)
158
+
159
+
160
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
161
+ """Applies Rotary Position Embedding to the query and key tensors.
162
+
163
+ Args:
164
+ q (`torch.Tensor`): The query tensor.
165
+ k (`torch.Tensor`): The key tensor.
166
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
167
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
168
+ position_ids (`torch.Tensor`, *optional*):
169
+ Deprecated and unused.
170
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
171
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
172
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
173
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
174
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
175
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
176
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
177
+ Returns:
178
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
179
+ """
180
+ cos = cos.unsqueeze(unsqueeze_dim)
181
+ sin = sin.unsqueeze(unsqueeze_dim)
182
+ q_embed = (q * cos) + (rotate_half(q) * sin)
183
+ k_embed = (k * cos) + (rotate_half(k) * sin)
184
+ return q_embed, k_embed
185
+
186
+
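
`rotate_half` and `apply_rotary_pos_emb` rotate each query/key vector by position-dependent cos/sin pairs. A self-contained shape check with a hand-built frequency table (toy dimensions, unrelated to the real config):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

b, h, s, d = 1, 2, 6, 8
q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)

inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))   # (d/2,)
freqs = torch.outer(torch.arange(s).float(), inv_freq)            # (s, d/2)
emb = torch.cat((freqs, freqs), dim=-1)                           # (s, d)
cos, sin = emb.cos()[None], emb.sin()[None]                       # (batch, s, d)

cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)                     # unsqueeze_dim=1
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # both torch.Size([1, 2, 6, 8])
```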
187
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
188
+ """
189
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
190
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
191
+ """
192
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
193
+ if n_rep == 1:
194
+ return hidden_states
195
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
196
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
197
+
198
+
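
As its docstring notes, `repeat_kv` is the expand/reshape equivalent of `torch.repeat_interleave(x, repeats=n_rep, dim=1)`; a quick equivalence check on random data:

```python
import torch

def repeat_kv(hidden_states, n_rep):
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

x = torch.randn(2, 4, 5, 8)  # 4 key/value heads
print(torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1)))  # True
```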
199
+ def eager_attention_forward(
200
+ module: nn.Module,
201
+ query: torch.Tensor,
202
+ key: torch.Tensor,
203
+ value: torch.Tensor,
204
+ attention_mask: Optional[torch.Tensor],
205
+ scaling: float,
206
+ dropout: float = 0.0,
207
+ **kwargs,
208
+ ):
209
+ key_states = repeat_kv(key, module.num_key_value_groups)
210
+ value_states = repeat_kv(value, module.num_key_value_groups)
211
+
212
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
213
+ if attention_mask is not None:
214
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
215
+ attn_weights = attn_weights + causal_mask
216
+
217
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
218
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
219
+ attn_output = torch.matmul(attn_weights, value_states)
220
+ attn_output = attn_output.transpose(1, 2).contiguous()
221
+
222
+ return attn_output, attn_weights
223
+
224
+
225
+ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
226
+ r"""
227
+ Args:
228
+ q (`torch.Tensor`): The query tensor.
229
+ k (`torch.Tensor`): The key tensor.
230
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
231
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
232
+ position_ids (`torch.Tensor`):
233
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
234
+ used to pass offsetted position ids when working with a KV-cache.
235
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
236
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
237
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
238
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
239
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
240
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
241
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
242
+ Returns:
243
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
244
+ """
245
+ cos = cos.unsqueeze(unsqueeze_dim)
246
+ sin = sin.unsqueeze(unsqueeze_dim)
247
+
248
+ b, h, s, d = q.shape
249
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
250
+
251
+ b, h, s, d = k.shape
252
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
253
+
254
+ q_embed = (q * cos) + (rotate_half(q) * sin)
255
+ k_embed = (k * cos) + (rotate_half(k) * sin)
256
+ return q_embed, k_embed
257
+
258
+
259
+ def yarn_get_mscale(scale=1, mscale=1):
260
+ if scale <= 1:
261
+ return 1.0
262
+ return 0.1 * mscale * math.log(scale) + 1.0
263
+
264
+
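
`yarn_get_mscale` grows logarithmically with the YaRN scaling factor, and the attention below applies it twice (once for queries, once for keys) on top of `qk_head_dim ** -0.5`. A small numeric check:

```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

print(yarn_get_mscale(1))           # 1.0  (no correction at or below factor 1)
print(yarn_get_mscale(4, 1))        # ~1.1386 = 0.1 * ln(4) + 1
print(yarn_get_mscale(4, 1) ** 2)   # ~1.2965, the factor multiplied into the softmax scaling
```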
265
+ class YoutuMLAttention(nn.Module):
266
+ """
267
+ Multi-head latent attention from the
268
+ 'DeepSeek-V2: A Strong, Economical,
269
+ and Efficient Mixture-of-Experts Language Model' paper.
270
+ """
271
+
272
+ def __init__(self, config: YoutuVLConfig, layer_idx: int):
273
+ super().__init__()
274
+ self.config = config
275
+ self.layer_idx = layer_idx
276
+ self.num_key_value_groups = 1 # needed for eager attentions
277
+ self.attention_dropout = config.attention_dropout
278
+ self.num_heads = config.num_attention_heads
279
+ self.rope_theta = config.rope_theta
280
+ self.q_lora_rank = config.q_lora_rank
281
+ self.qk_rope_head_dim = config.qk_rope_head_dim
282
+ self.kv_lora_rank = config.kv_lora_rank
283
+ self.v_head_dim = config.v_head_dim
284
+ self.qk_nope_head_dim = config.qk_nope_head_dim
285
+ self.qk_head_dim = config.qk_head_dim
286
+ self.flash_att_sliding_window = config.flash_att_sliding_window
287
+ self.is_causal = True
288
+
289
+ if self.q_lora_rank is None:
290
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
291
+ else:
292
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
293
+ self.q_a_layernorm = YoutuRMSNorm(config.q_lora_rank)
294
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
295
+
296
+ self.kv_a_proj_with_mqa = nn.Linear(
297
+ config.hidden_size,
298
+ self.kv_lora_rank + self.qk_rope_head_dim,
299
+ bias=config.attention_bias,
300
+ )
301
+ self.kv_a_layernorm = YoutuRMSNorm(self.kv_lora_rank)
302
+ self.kv_b_proj = nn.Linear(
303
+ self.kv_lora_rank,
304
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
305
+ bias=False,
306
+ )
307
+
308
+ self.o_proj = nn.Linear(
309
+ self.num_heads * self.v_head_dim,
310
+ config.hidden_size,
311
+ bias=config.attention_bias,
312
+ )
313
+
314
+ self.scaling = self.qk_head_dim ** (-0.5)
315
+ if self.config.rope_scaling is not None:
316
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
317
+ scaling_factor = self.config.rope_scaling["factor"]
318
+ if mscale_all_dim:
319
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
320
+ self.scaling = self.scaling * mscale * mscale
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: torch.Tensor,
325
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
326
+ attention_mask: Optional[torch.Tensor],
327
+ instance_length: Optional[torch.LongTensor] = None,
328
+ past_key_value: Optional[Cache] = None,
329
+ cache_position: Optional[torch.LongTensor] = None,
330
+ **kwargs: Unpack[FlashAttentionKwargs],
331
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
332
+ batch_size, seq_length = hidden_states.shape[:-1]
333
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
334
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
335
+
336
+ if self.q_lora_rank is None:
337
+ q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2)
338
+ else:
339
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
340
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
341
+
342
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
343
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
344
+
345
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
346
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
347
+
348
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
349
+
350
+ cos, sin = position_embeddings
351
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
352
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
353
+ else:
354
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
355
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
356
+
357
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
358
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
359
+
360
+ if past_key_value is not None:
361
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
362
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
363
+
364
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
365
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
366
+
367
+ attention_interface: Callable = eager_attention_forward
368
+ if self.config._attn_implementation != "eager":
369
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
370
+ logger.warning_once(
371
+ "`torch.nn.functional.scaled_dot_product_attention` does not support"
372
+ "`output_attentions=True`. Falling back to 'eager attention. This warning"
373
+ 'can be removed using the argument `attn_implementation="eager"` when loading the model.'
374
+ )
375
+ else:
376
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
377
+
378
+ if instance_length is None or flash_attn_varlen_func is None:
379
+ attn_output, attn_weights = attention_interface(
380
+ self,
381
+ query_states,
382
+ key_states,
383
+ value_states,
384
+ attention_mask,
385
+ dropout=0.0 if not self.training else self.attention_dropout,
386
+ scaling=self.scaling,
387
+ **kwargs,
388
+ )
389
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
390
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
391
+ else:
392
+ instance_length = instance_length.view(-1)
393
+ query_states = query_states.squeeze(0).transpose(0,1)
394
+ key_states = key_states.squeeze(0).transpose(0,1)
395
+ value_states = value_states.squeeze(0).transpose(0,1)
396
+ max_seqlen_in_batch = instance_length.max().item()
397
+ cu_seqlens = F.pad(torch.cumsum(instance_length, dim=0, dtype=torch.int32), (1, 0))
398
+ if is_aiter_available:
399
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens,
400
+ cu_seqlens, max_seqlen_in_batch, max_seqlen_in_batch,
401
+ dropout_p=0.0 if not self.training else self.attention_dropout,
402
+ softmax_scale=self.scaling,
403
+ causal=self.is_causal, return_lse=True)[0]
404
+ else:
405
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens,
406
+ cu_seqlens, max_seqlen_in_batch, max_seqlen_in_batch,
407
+ dropout_p=0.0 if not self.training else self.attention_dropout,
408
+ softmax_scale=self.scaling,
409
+ causal=self.is_causal)
410
+
411
+ attn_output = attn_output.unsqueeze(0)
412
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
413
+ attn_weights = None
414
+
415
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
416
+ attn_output = self.o_proj(attn_output)
417
+ return attn_output, attn_weights
418
+
419
+
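
In the multi-latent attention above, queries and keys carry a no-RoPE part (`qk_nope_head_dim`) plus a shared RoPE part (`qk_rope_head_dim`), while per-head keys and values are reconstructed from a compressed `kv_lora_rank` latent. A dimension-only sketch with invented sizes (the real values come from `YoutuVLConfig`):

```python
import torch
from torch import nn

# hypothetical dimensions, for shape bookkeeping only
hidden, heads = 64, 4
qk_nope, qk_rope, v_dim, kv_rank = 16, 8, 16, 32

kv_a = nn.Linear(hidden, kv_rank + qk_rope, bias=False)            # compress + shared RoPE key
kv_b = nn.Linear(kv_rank, heads * (qk_nope + v_dim), bias=False)   # decompress per head

x = torch.randn(1, 10, hidden)
compressed = kv_a(x)
k_pass, k_rot = torch.split(compressed, [kv_rank, qk_rope], dim=-1)
kv = kv_b(k_pass).view(1, 10, heads, qk_nope + v_dim).transpose(1, 2)
k_nope, v = torch.split(kv, [qk_nope, v_dim], dim=-1)
print(k_nope.shape, v.shape, k_rot.shape)
# torch.Size([1, 4, 10, 16]) torch.Size([1, 4, 10, 16]) torch.Size([1, 10, 8])
```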
420
+ class YoutuDecoderLayer(nn.Module):
421
+ def __init__(self, config: YoutuVLConfig, layer_idx: int):
422
+ super().__init__()
423
+ self.hidden_size = config.hidden_size
424
+ self.self_attn = YoutuMLAttention(config=config, layer_idx=layer_idx)
425
+ self.mlp = YoutuMLP(config)
426
+ self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
427
+ self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.Tensor] = None,
433
+ position_ids: Optional[torch.LongTensor] = None,
434
+ past_key_value: Optional[Cache] = None,
435
+ output_attentions: Optional[bool] = False,
436
+ instance_length: Optional[torch.LongTensor] = None,
437
+ use_cache: Optional[bool] = False,
438
+ cache_position: Optional[torch.LongTensor] = None,
439
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
+ **kwargs: Unpack[FlashAttentionKwargs],
441
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
442
+ residual = hidden_states
443
+
444
+ hidden_states = self.input_layernorm(hidden_states)
445
+
446
+ hidden_states, self_attn_weights = self.self_attn(
447
+ hidden_states=hidden_states,
448
+ attention_mask=attention_mask,
449
+ position_ids=position_ids,
450
+ past_key_value=past_key_value,
451
+ output_attentions=output_attentions,
452
+ instance_length=instance_length,
453
+ use_cache=use_cache,
454
+ cache_position=cache_position,
455
+ position_embeddings=position_embeddings,
456
+ **kwargs,
457
+ )
458
+ hidden_states = residual + hidden_states
459
+
460
+ residual = hidden_states
461
+ hidden_states = self.post_attention_layernorm(hidden_states)
462
+ hidden_states = self.mlp(hidden_states)
463
+ hidden_states = residual + hidden_states
464
+
465
+ outputs = (hidden_states,)
466
+ if output_attentions:
467
+ outputs += (self_attn_weights,)
468
+
469
+ return outputs
470
+
471
+
472
+ YOUTU_VL_START_DOCSTRING = r"""
473
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
474
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
475
+ etc.)
476
+
477
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
478
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
479
+ and behavior.
480
+
481
+ Parameters:
482
+ config ([`YoutuVLConfig`]):
483
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
484
+ load the weights associated with the model, only the configuration. Check out the
485
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
486
+ """
487
+
488
+
489
+ @add_start_docstrings(
490
+ "The bare Youtu Model outputting raw hidden-states without any specific head on top.",
491
+ YOUTU_VL_START_DOCSTRING,
492
+ )
493
+ class YoutuPreTrainedModel(PreTrainedModel):
494
+ config_class = YoutuVLConfig
495
+ base_model_prefix = "model"
496
+ supports_gradient_checkpointing = True
497
+ _no_split_modules = ["YoutuDecoderLayer"]
498
+ _skip_keys_device_placement = ["past_key_values"]
499
+ _supports_flash_attn_2 = True
500
+ _supports_sdpa = True
501
+ _supports_flex_attn = True
502
+ _supports_cache_class = True
503
+ _supports_quantized_cache = True
504
+ _supports_static_cache = True
505
+ _supports_attention_backend = True
506
+
507
+ def init_weights(self):
508
+ if self.config.pruned_heads:
509
+ self.prune_heads(self.config.pruned_heads)
510
+
511
+ if "-init" in self.name_or_path:
512
+ self.apply(self._initialize_weights)
513
+
514
+ for name, module in self.named_modules():
515
+ if "o_proj" in name or "down_proj" in name:
516
+ scaled_std = self.config.initializer_range * (1.0 / self.config.num_hidden_layers) ** 0.5
517
+ module.weight.data.normal_(mean=0.0, std=scaled_std)
518
+
519
+ self.tie_weights()
520
+
521
+ def _init_weights(self, module):
522
+ std = self.config.initializer_range
523
+ embedding_std = self.config.embedding_initializer_range
524
+ if isinstance(module, nn.Linear):
525
+ module.weight.data.normal_(mean=0.0, std=std)
526
+ if module.bias is not None:
527
+ module.bias.data.zero_()
528
+ elif isinstance(module, nn.Embedding):
529
+ module.weight.data.normal_(mean=0.0, std=embedding_std)
530
+ if module.padding_idx is not None:
531
+ module.weight.data[module.padding_idx].zero_()
532
+ elif isinstance(module, nn.Parameter):
533
+ module.data.normal_(mean=0.0, std=std)
534
+ elif isinstance(module, YoutuRMSNorm):
535
+ module.weight.data.fill_(1.0)
536
+
537
+
538
+ YOUTU_VL_INPUTS_DOCSTRING = r"""
539
+ Args:
540
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
541
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
542
+ it.
543
+
544
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
545
+ [`PreTrainedTokenizer.__call__`] for details.
546
+
547
+ [What are input IDs?](../glossary#input-ids)
548
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
549
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
550
+
551
+ - 1 for tokens that are **not masked**,
552
+ - 0 for tokens that are **masked**.
553
+
554
+ [What are attention masks?](../glossary#attention-mask)
555
+
556
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
557
+ [`PreTrainedTokenizer.__call__`] for details.
558
+
559
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
560
+ `past_key_values`).
561
+
562
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
563
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
564
+ information on the default strategy.
565
+
566
+ - 1 indicates the head is **not masked**,
567
+ - 0 indicates the head is **masked**.
568
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
569
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
570
+ config.n_positions - 1]`.
571
+
572
+ [What are position IDs?](../glossary#position-ids)
573
+ past_key_values (`Cache`, *optional*):
574
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
575
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
576
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
577
+
578
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
579
+
580
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
581
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
582
+ of shape `(batch_size, sequence_length)`.
583
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
584
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
585
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
586
+ model's internal embedding lookup matrix.
587
+ use_cache (`bool`, *optional*):
588
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
589
+ `past_key_values`).
590
+ output_attentions (`bool`, *optional*):
591
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
592
+ tensors for more detail.
593
+ output_hidden_states (`bool`, *optional*):
594
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
595
+ more detail.
596
+ return_dict (`bool`, *optional*):
597
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
598
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
599
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
600
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
601
+ the complete sequence length.
602
+ """
603
+
604
+
605
+ @add_start_docstrings(
606
+ "The bare Youtu Model outputting raw hidden-states without any specific head on top.",
607
+ YOUTU_VL_START_DOCSTRING,
608
+ )
609
+ class YoutuModel(YoutuPreTrainedModel):
610
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
611
+
612
+ def __init__(self, config: YoutuVLConfig):
613
+ super().__init__(config)
614
+ self.padding_idx = config.pad_token_id
615
+ self.vocab_size = config.vocab_size
616
+
617
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
618
+ self.layers = nn.ModuleList(
619
+ [YoutuDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
620
+ )
621
+ self.norm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
622
+ self.rotary_emb = YoutuRotaryEmbedding(config=config)
623
+ self.gradient_checkpointing = False
624
+
625
+ # Initialize weights and apply final processing
626
+ self.post_init()
627
+
628
+ def get_input_embeddings(self):
629
+ return self.embed_tokens
630
+
631
+ def set_input_embeddings(self, value):
632
+ self.embed_tokens = value
633
+
634
+ @can_return_tuple
635
+ @add_start_docstrings_to_model_forward(YOUTU_VL_INPUTS_DOCSTRING)
636
+ def forward(
637
+ self,
638
+ input_ids: Optional[torch.LongTensor] = None,
639
+ attention_mask: Optional[torch.Tensor] = None,
640
+ position_ids: Optional[torch.LongTensor] = None,
641
+ past_key_values: Optional[Cache] = None,
642
+ inputs_embeds: Optional[torch.FloatTensor] = None,
643
+ use_cache: Optional[bool] = None,
644
+ instance_length: Optional[torch.LongTensor] = None,
645
+ output_attentions: Optional[bool] = None,
646
+ output_hidden_states: Optional[bool] = None,
647
+ cache_position: Optional[torch.LongTensor] = None,
648
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
649
+ ) -> BaseModelOutputWithPast:
650
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
651
+ output_hidden_states = (
652
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
653
+ )
654
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
655
+
656
+ if (input_ids is None) ^ (inputs_embeds is not None):
657
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
658
+
659
+ if inputs_embeds is None:
660
+ inputs_embeds = self.embed_tokens(input_ids)
661
+
662
+ if use_cache and past_key_values is None:
663
+ past_key_values = DynamicCache()
664
+
665
+ if cache_position is None:
666
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
667
+ cache_position = torch.arange(
668
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
669
+ )
670
+
671
+ if position_ids is None:
672
+ position_ids = cache_position.unsqueeze(0)
673
+
674
+ causal_mask = self._update_causal_mask(
675
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
676
+ )
677
+
678
+ hidden_states = inputs_embeds
679
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
680
+
681
+ all_hidden_states = () if output_hidden_states else None
682
+ all_self_attns = () if output_attentions else None
683
+
684
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
685
+ if output_hidden_states:
686
+ all_hidden_states += (hidden_states,)
687
+ layer_outputs = decoder_layer(
688
+ hidden_states,
689
+ attention_mask=causal_mask,
690
+ position_ids=position_ids,
691
+ past_key_value=past_key_values,
692
+ output_attentions=output_attentions,
693
+ instance_length=instance_length,
694
+ use_cache=use_cache,
695
+ cache_position=cache_position,
696
+ position_embeddings=position_embeddings,
697
+ **flash_attn_kwargs,
698
+ )
699
+ hidden_states = layer_outputs[0]
700
+ if output_attentions:
701
+ all_self_attns += (layer_outputs[1],)
702
+
703
+ hidden_states = self.norm(hidden_states)
704
+
705
+ if output_hidden_states:
706
+ all_hidden_states += (hidden_states,)
707
+
708
+ return BaseModelOutputWithPast(
709
+ last_hidden_state=hidden_states,
710
+ past_key_values=past_key_values if use_cache else None,
711
+ hidden_states=all_hidden_states,
712
+ attentions=all_self_attns,
713
+ )
714
+
715
+ def _update_causal_mask(
716
+ self,
717
+ attention_mask: torch.Tensor,
718
+ input_tensor: torch.Tensor,
719
+ cache_position: torch.Tensor,
720
+ past_key_values: Cache,
721
+ output_attentions: bool = False,
722
+ ):
723
+ if self.config._attn_implementation == "flash_attention_2":
724
+ if attention_mask is not None and (attention_mask == 0.0).any():
725
+ return attention_mask
726
+ return None
727
+
728
+ if self.config._attn_implementation == "flex_attention":
729
+ if isinstance(attention_mask, torch.Tensor):
730
+ attention_mask = make_flex_block_causal_mask(attention_mask)
731
+ if isinstance(attention_mask, BlockMask):
732
+ return attention_mask
733
+
734
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
735
+ using_static_cache = isinstance(past_key_values, StaticCache)
736
+
737
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
738
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
739
+ attention_mask,
740
+ inputs_embeds=input_tensor,
741
+ past_key_values_length=past_seen_tokens,
742
+ is_training=self.training,
743
+ ):
744
+ return None
745
+
746
+ dtype, device = input_tensor.dtype, input_tensor.device
747
+ sequence_length = input_tensor.shape[1]
748
+ if using_static_cache:
749
+ target_length = past_key_values.get_max_cache_shape()
750
+ else:
751
+ target_length = (
752
+ attention_mask.shape[-1]
753
+ if isinstance(attention_mask, torch.Tensor)
754
+ else past_seen_tokens + sequence_length + 1
755
+ )
756
+
757
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
758
+ attention_mask,
759
+ sequence_length=sequence_length,
760
+ target_length=target_length,
761
+ dtype=dtype,
762
+ device=device,
763
+ cache_position=cache_position,
764
+ batch_size=input_tensor.shape[0],
765
+ )
766
+
767
+ if (
768
+ self.config._attn_implementation == "sdpa"
769
+ and attention_mask is not None
770
+ and attention_mask.device.type in ["cuda", "xpu"]
771
+ and not output_attentions
772
+ ):
773
+ min_dtype = torch.finfo(dtype).min
774
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
775
+
776
+ return causal_mask
777
+
778
+ @staticmethod
779
+ def _prepare_4d_causal_attention_mask_with_cache_position(
780
+ attention_mask: torch.Tensor,
781
+ sequence_length: int,
782
+ target_length: int,
783
+ dtype: torch.dtype,
784
+ device: torch.device,
785
+ cache_position: torch.Tensor,
786
+ batch_size: int,
787
+ **kwargs,
788
+ ):
789
+ """
790
+ Args:
791
+ attention_mask (`torch.Tensor`):
792
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
793
+ `(batch_size, 1, query_length, key_value_length)`.
794
+ sequence_length (`int`):
795
+ The sequence length being processed.
796
+ target_length (`int`):
797
+ The target length: when generating with static cache, the mask should be as long as the static cache,
798
+ to account for the 0 padding, the part of the cache that is not filled yet.
799
+ dtype (`torch.dtype`):
800
+ The dtype to use for the 4D attention mask.
801
+ device (`torch.device`):
802
+ The device to place the 4D attention mask on.
803
+ cache_position (`torch.Tensor`):
804
+ Indices depicting the position of the input sequence tokens in the sequence.
805
+ batch_size (`torch.Tensor`):
806
+ Batch size.
807
+ """
808
+ if attention_mask is not None and attention_mask.dim() == 4:
809
+ causal_mask = attention_mask
810
+ else:
811
+ min_dtype = torch.finfo(dtype).min
812
+ causal_mask = torch.full(
813
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
814
+ )
815
+ if sequence_length != 1:
816
+ causal_mask = torch.triu(causal_mask, diagonal=1)
817
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
818
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
819
+ if attention_mask is not None:
820
+ causal_mask = causal_mask.clone()
821
+ mask_length = attention_mask.shape[-1]
822
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
823
+ causal_mask.device
824
+ )
825
+ padding_mask = padding_mask == 0
826
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
827
+ padding_mask, min_dtype
828
+ )
829
+
830
+ return causal_mask
831
+
832
+
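
`_prepare_4d_causal_attention_mask_with_cache_position` builds an additive `(batch, 1, query_len, target_len)` mask whose blocked positions hold the dtype minimum. A small reproduction of the triu/cache_position construction (two cached tokens, three new ones):

```python
import torch

seq_len, target_len = 3, 5                  # 3 new tokens, 2 already cached
dtype = torch.float32
min_val = torch.finfo(dtype).min
cache_position = torch.arange(2, 5)         # absolute positions of the new tokens

mask = torch.full((seq_len, target_len), min_val, dtype=dtype)
mask = torch.triu(mask, diagonal=1)
mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
print((mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```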
833
+ class KwargsForCausalLM(FlashAttentionKwargs): ...
834
+
835
+
836
+ class YoutuForCausalLM(YoutuPreTrainedModel, GenerationMixin):
837
+ _tied_weights_keys = ["lm_head.weight"]
838
+ _tp_plan = {"lm_head": "colwise_rep"}
839
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
840
+
841
+ def __init__(self, config):
842
+ super().__init__(config)
843
+
844
+ self.model = YoutuModel(config)
845
+ self.vocab_size = config.vocab_size
846
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
847
+
848
+ self.post_init()
849
+
850
+ def get_input_embeddings(self):
851
+ return self.model.embed_tokens
852
+
853
+ def set_input_embeddings(self, value):
854
+ self.model.embed_tokens = value
855
+
856
+ def get_output_embeddings(self):
857
+ return self.lm_head
858
+
859
+ def set_output_embeddings(self, new_embeddings):
860
+ self.lm_head = new_embeddings
861
+
862
+ def set_decoder(self, decoder):
863
+ self.model = decoder
864
+
865
+ def get_decoder(self):
866
+ return self.model
867
+
868
+ def get_merge_embedding(self, inputs_embeds, image_embeds, image_mask, **kwargs):
869
+ bs, length, dim_size = inputs_embeds.shape
870
+ if image_embeds is None:
871
+ return inputs_embeds
872
+ if bs == 1:
873
+ image_embeds = image_embeds.unsqueeze(0)
874
+ init_inputs_embeds = inputs_embeds.clone()
875
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
876
+ cmp_mask = torch.isclose(init_inputs_embeds, inputs_embeds, rtol=1e-05, atol=1e-08)
877
+ else:
878
+ raise ValueError("get_merge_embedding only supports batch_size == 1")
879
+
880
+ return inputs_embeds
881
+ @can_return_tuple
882
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
883
+ @add_start_docstrings_to_model_forward(YOUTU_VL_INPUTS_DOCSTRING)
884
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
885
+ def forward(
886
+ self,
887
+ input_ids: Optional[torch.LongTensor] = None,
888
+ attention_mask: Optional[torch.Tensor] = None,
889
+ position_ids: Optional[torch.LongTensor] = None,
890
+ past_key_values: Optional[Cache] = None,
891
+ inputs_embeds: Optional[torch.FloatTensor] = None,
892
+ labels: Optional[torch.LongTensor] = None,
893
+ use_cache: Optional[bool] = None,
894
+ output_attentions: Optional[bool] = None,
895
+ output_hidden_states: Optional[bool] = None,
896
+ cache_position: Optional[torch.LongTensor] = None,
897
+ logits_to_keep: Union[int, torch.Tensor] = 0,
898
+ **kwargs: Unpack[KwargsForCausalLM],
899
+ ) -> CausalLMOutputWithPast:
900
+ r"""
901
+ Returns:
902
+
903
+ """
904
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
905
+ output_hidden_states = (
906
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
907
+ )
908
+
909
+ outputs: BaseModelOutputWithPast = self.model(
910
+ input_ids=input_ids,
911
+ attention_mask=attention_mask,
912
+ position_ids=position_ids,
913
+ past_key_values=past_key_values,
914
+ inputs_embeds=inputs_embeds,
915
+ use_cache=use_cache,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ cache_position=cache_position,
919
+ **kwargs,
920
+ )
921
+
922
+ hidden_states = outputs.last_hidden_state
923
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
924
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
925
+
926
+ loss = None
927
+
928
+ return CausalLMOutputWithPast(
929
+ loss=loss,
930
+ logits=logits,
931
+ past_key_values=outputs.past_key_values,
932
+ hidden_states=outputs.hidden_states,
933
+ attentions=outputs.attentions,
934
+ )
935
+
936
+ class VLPatchMerger(nn.Module):
937
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
938
+ super().__init__()
939
+ self.hidden_size = context_dim * (spatial_merge_size**2)
940
+ self.ln_q = YoutuRMSNorm(context_dim, eps=1e-06)
941
+ self.mlp = nn.Sequential(
942
+ nn.Linear(self.hidden_size, self.hidden_size),
943
+ nn.GELU(),
944
+ nn.Linear(self.hidden_size, dim),
945
+ )
946
+
947
+ def forward(self, x: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
948
+ x = self.ln_q(x).view(-1, self.hidden_size)
949
+ x = self.mlp(x)
950
+ return x
951
+
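
`VLPatchMerger` folds every `spatial_merge_size ** 2` neighbouring vision tokens into a single language-model token: after RMSNorm the features are regrouped to width `context_dim * 4` and projected to the LM hidden size. A shape-only sketch with invented sizes:

```python
import torch
from torch import nn

context_dim, dim, merge = 32, 64, 2        # hypothetical sizes
hidden_size = context_dim * merge**2       # 128: four patches concatenated

mlp = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, dim),
)

vision_tokens = torch.randn(16, context_dim)        # 16 patches from the vision tower
merged = mlp(vision_tokens.view(-1, hidden_size))   # groups of 4 patches -> 4 LM tokens
print(merged.shape)  # torch.Size([4, 64])
```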
952
+ class YoutuVLForConditionalGeneration(YoutuPreTrainedModel, GenerationMixin):
953
+ _tied_weights_keys = ["lm_head.weight"]
954
+ _tp_plan = {"lm_head": "colwise_rep"}
955
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
956
+
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+
960
+ config.vision_config.out_hidden_size = config.hidden_size
961
+ config.vision_config.vision_use_head = False
962
+ self.siglip2 = Siglip2VisionModel._from_config(config.vision_config)
963
+ self.merger = VLPatchMerger(
964
+ dim=config.hidden_size,
965
+ context_dim=config.vision_config.hidden_size,
966
+ spatial_merge_size=2,
967
+ )
968
+ self.rope_deltas = None
969
+
970
+ self.model = YoutuModel(config)
971
+ self.vocab_size = config.vocab_size
972
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
973
+ self.first_logits = None
974
+
975
+ self.post_init()
976
+
977
+
978
+ def get_input_embeddings(self):
979
+ return self.model.embed_tokens
980
+
981
+ def set_input_embeddings(self, value):
982
+ self.model.embed_tokens = value
983
+
984
+ def get_output_embeddings(self):
985
+ return self.lm_head
986
+
987
+ def set_output_embeddings(self, new_embeddings):
988
+ self.lm_head = new_embeddings
989
+
990
+ def set_decoder(self, decoder):
991
+ self.model = decoder
992
+
993
+ def get_decoder(self):
994
+ return self.model
995
+
996
+ def get_input_idx_embeddings(self, input_ids):
997
+ inputs_embeds = self.model.embed_tokens(input_ids)
998
+ return inputs_embeds
999
+
1000
+ def get_visiual_features(self, pixel_values, pixel_attention_mask, spatial_shapes):
1001
+ pixel_values = pixel_values.type(self.siglip2.dtype)
1002
+
1003
+ # Extract image embeddings via vision model
1004
+ image_embeds = self.siglip2(pixel_values, pixel_attention_mask, spatial_shapes).last_hidden_state
1005
+ # Merge image features with the output of vision model
1006
+ image_embeds = self.merger(image_embeds, spatial_shapes)
1007
+
1008
+ return image_embeds
1009
+
1010
+
1011
+ def get_merge_embedding(self, inputs_embeds, image_embeds, image_mask, **kwargs):
1012
+ """
1013
+ Merge text embeddings with image embeddings using the provided mask.
1014
+
1015
+ Args:
1016
+ inputs_embeds: Text input embeddings
1017
+ image_embeds: Image embeddings to merge
1018
+ image_mask: Mask indicating where to place image embeddings
1019
+ **kwargs: Additional keyword arguments
1020
+
1021
+ Returns:
1022
+ Merged embeddings with image features integrated
1023
+ """
1024
+ bs, length, dim_size = inputs_embeds.shape
1025
+ if image_embeds is None:
1026
+ return inputs_embeds
1027
+ if bs == 1:
1028
+ image_embeds = image_embeds.unsqueeze(0)
1029
+ init_inputs_embeds = inputs_embeds.clone()
1030
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1031
+ cmp_mask = torch.isclose(init_inputs_embeds, inputs_embeds, rtol=1e-05, atol=1e-08)
1032
+ else:
1033
+ # batch sizes other than 1 are not supported when merging image embeddings
1034
+ raise ValueError("get_merge_embedding only supports batch_size == 1")
1035
+
1036
+ return inputs_embeds
1037
+
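
`get_merge_embedding` drops the projected image features into the positions flagged by the image-token mask via `masked_scatter`, so the number of masked positions must equal the number of image embedding rows. A toy illustration (hypothetical placeholder id and sizes):

```python
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id
input_ids = torch.tensor([[1, 9, 9, 2]])
inputs_embeds = torch.zeros(1, 4, 3)
image_embeds = torch.arange(6, dtype=torch.float32).view(2, 3)  # one row per image token

image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(image_mask, image_embeds.unsqueeze(0))
print(merged[0])  # rows 1 and 2 now carry the image features, rows 0 and 3 stay text
```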
1038
+ @can_return_tuple
1039
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
1040
+ @add_start_docstrings_to_model_forward(YOUTU_VL_INPUTS_DOCSTRING)
1041
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1042
+ def forward(
1043
+ self,
1044
+ input_ids: Optional[torch.LongTensor] = None,
1045
+ attention_mask: Optional[torch.Tensor] = None,
1046
+ position_ids: Optional[torch.LongTensor] = None,
1047
+ past_key_values: Optional[Cache] = None,
1048
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1049
+ labels: Optional[torch.LongTensor] = None,
1050
+ use_cache: Optional[bool] = None,
1051
+ output_attentions: Optional[bool] = None,
1052
+ output_hidden_states: Optional[bool] = None,
1053
+ pixel_values: Optional[torch.Tensor] = None,
1054
+ pixel_attention_mask: Optional[torch.LongTensor] = None,
1055
+ spatial_shapes: Optional[torch.LongTensor] = None,
1056
+ instance_length: Optional[torch.LongTensor] = None,
1057
+ coefficients: Optional[torch.FloatTensor] = None,
1058
+ rope_deltas: Optional[torch.LongTensor] = None,
1059
+ cache_position: Optional[torch.LongTensor] = None,
1060
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1061
+ **kwargs: Unpack[KwargsForCausalLM],
1062
+ ) -> CausalLMOutputWithPast:
1063
+ r"""
1064
+ Example:
1065
+ TODO: Add example
1066
+
1067
+ Returns:
1068
+ """
1069
+
1070
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1071
+ output_hidden_states = (
1072
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1073
+ )
1074
+
1075
+ if inputs_embeds is None:
1076
+ inputs_embeds = self.model.embed_tokens(input_ids)
1077
+
1078
+ if pixel_values is not None:
1079
+ bs, length, dim_size = inputs_embeds.shape
1080
+ pixel_values = pixel_values.type(self.siglip2.dtype)
1081
+
1082
+ image_embeds = self.siglip2(pixel_values, pixel_attention_mask, spatial_shapes).last_hidden_state
1083
+ image_embeds = self.merger(image_embeds, spatial_shapes)
1084
+
1085
+ n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
1086
+ n_image_features = image_embeds.shape[0]
1087
+
1088
+ if n_image_tokens > n_image_features:
1089
+ raise ValueError(
1090
+ "Image features and image tokens do not match: tokens: {}, features {}".format(
1091
+ n_image_tokens, n_image_features
1092
+ )
1093
+ )
1094
+
1095
+ mask = input_ids == self.config.image_token_id
1096
+ mask_unsqueezed = mask.unsqueeze(-1)
1097
+ mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
1098
+ image_mask = mask_expanded.to(inputs_embeds.device)
1099
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
1100
+
1101
+ if bs != 1:
1102
+ raise ValueError("Only support batch size = 1")
1103
+
1104
+ image_embeds = image_embeds.unsqueeze(0)
1105
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1106
+
1107
+ if attention_mask is not None:
1108
+ attention_mask = attention_mask.to(inputs_embeds.device)
1109
+
1110
+ outputs: BaseModelOutputWithPast = self.model(
1111
+ input_ids=None,
1112
+ attention_mask=attention_mask,
1113
+ position_ids=position_ids,
1114
+ past_key_values=past_key_values,
1115
+ inputs_embeds=inputs_embeds,
1116
+ use_cache=use_cache,
1117
+ output_attentions=output_attentions,
1118
+ output_hidden_states=output_hidden_states,
1119
+ cache_position=cache_position,
1120
+ instance_length=instance_length,
1121
+ **kwargs,
1122
+ )
1123
+
1124
+ hidden_states = outputs.last_hidden_state
1125
+ logits = self.lm_head(hidden_states)
1126
+ if logits.shape[1] != 1:
1127
+ self.first_logits = logits
1128
+ loss = None
1129
+
1130
+ return CausalLMOutputWithPast(
1131
+ loss=loss,
1132
+ logits=logits,
1133
+ past_key_values=outputs.past_key_values,
1134
+ hidden_states=outputs.hidden_states,
1135
+ attentions=outputs.attentions,
1136
+ )
1137
+ def truncate_past_key_values(
1138
+ self,
1139
+ past_key_values: Optional[DynamicCache],
1140
+ num_history: int
1141
+ ) -> Optional[DynamicCache]:
1142
+ """Truncate past_key_values to specified history length in-place.
1143
+
1144
+ Args:
1145
+ past_key_values: Cache object to truncate
1146
+ num_history: Target history length to keep
1147
+
1148
+ Returns:
1149
+ Truncated cache object or None if input is None
1150
+ """
1151
+ if past_key_values is None:
1152
+ return None
1153
+
1154
+ current_length = past_key_values.get_seq_length()
1155
+ if current_length <= num_history:
1156
+ return past_key_values
1157
+
1158
+ for layer_idx in range(len(past_key_values.key_cache)):
1159
+ if past_key_values.key_cache[layer_idx] is not None:
1160
+ past_key_values.key_cache[layer_idx] = (
1161
+ past_key_values.key_cache[layer_idx][:, :, :num_history, :].contiguous()
1162
+ )
1163
+ past_key_values.value_cache[layer_idx] = (
1164
+ past_key_values.value_cache[layer_idx][:, :, :num_history, :].contiguous()
1165
+ )
1166
+
1167
+ return past_key_values
1168
+
1169
+ def clone_past_key_values(
1170
+ self,
1171
+ past_key_values: Optional[DynamicCache]
1172
+ ) -> Optional[DynamicCache]:
1173
+ """Deep copy past_key_values to avoid shared reference issues.
1174
+
1175
+ Args:
1176
+ past_key_values: Cache object to clone
1177
+
1178
+ Returns:
1179
+ Deep copied cache object or None if input is None
1180
+ """
1181
+ if past_key_values is None:
1182
+ return None
1183
+
1184
+ new_cache = DynamicCache()
1185
+ for layer_idx in range(len(past_key_values.key_cache)):
1186
+ if past_key_values.key_cache[layer_idx] is not None:
1187
+ new_cache.key_cache.append(past_key_values.key_cache[layer_idx].clone())
1188
+ new_cache.value_cache.append(past_key_values.value_cache[layer_idx].clone())
1189
+
1190
+ return new_cache
1191
+
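
`clone_past_key_values` and `truncate_past_key_values` operate directly on the per-layer `key_cache`/`value_cache` lists that these helpers assume `DynamicCache` exposes. A standalone sketch of both operations on a tiny single-layer cache:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 2, 10, 4), torch.randn(1, 2, 10, 4), layer_idx=0)
print(cache.get_seq_length())  # 10

# clone: deep copy so a speculative decoding pass cannot corrupt the original
cloned = DynamicCache()
for k, v in zip(cache.key_cache, cache.value_cache):
    cloned.key_cache.append(k.clone())
    cloned.value_cache.append(v.clone())

# truncate: keep only the first num_history cached positions
num_history = 6
for i in range(len(cloned.key_cache)):
    cloned.key_cache[i] = cloned.key_cache[i][:, :, :num_history, :].contiguous()
    cloned.value_cache[i] = cloned.value_cache[i][:, :, :num_history, :].contiguous()
print(cloned.get_seq_length(), cache.get_seq_length())  # 6 10
```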
1192
+ def concat_token_ids(
1193
+ self,
1194
+ input_ids: torch.Tensor,
1195
+ concat_ids: Optional[List[int]]
1196
+ ) -> torch.Tensor:
1197
+ """Concatenate additional token IDs to input sequence.
1198
+
1199
+ Args:
1200
+ input_ids: Original input token IDs of shape (batch_size, seq_len)
1201
+ concat_ids: Token IDs to concatenate
1202
+
1203
+ Returns:
1204
+ Concatenated token IDs tensor
1205
+ """
1206
+ if concat_ids is None:
1207
+ return input_ids
1208
+
1209
+ num_gen = len(concat_ids)
1210
+ if num_gen < 2:
1211
+ return input_ids
1212
+
1213
+ batch_size = input_ids.size(0)
1214
+ concat_token_tensor = torch.tensor(
1215
+ concat_ids,
1216
+ dtype=input_ids.dtype,
1217
+ device=input_ids.device
1218
+ )
1219
+ concat_tokens = concat_token_tensor.unsqueeze(0).repeat(batch_size, 1)
1220
+ new_input_ids = torch.cat([input_ids, concat_tokens], dim=1)
1221
+
1222
+ return new_input_ids
1223
+
1224
+ def create_causal_mask_for_kv_cache(
1225
+ self,
1226
+ kv_cache_len: int,
1227
+ num_new_tokens: int,
1228
+ device: torch.device,
1229
+ dtype: torch.dtype = torch.bfloat16
1230
+ ) -> torch.Tensor:
1231
+ """Create causal attention mask for KV cache usage.
1232
+
1233
+ Each new token can only see:
1234
+ 1. All content in KV cache (positions 0 to kv_cache_len-1)
1235
+ 2. Previous new tokens and itself (causal masking)
1236
+
1237
+ Args:
1238
+ kv_cache_len: Length of existing sequence in KV cache
1239
+ num_new_tokens: Number of new tokens being added
1240
+ device: Target device for tensor allocation
1241
+ dtype: Data type for the mask tensor
1242
+
1243
+ Returns:
1244
+ Attention mask of shape (1, 1, num_new_tokens, kv_cache_len + num_new_tokens)
1245
+ """
1246
+ total_len = kv_cache_len + num_new_tokens
1247
+ min_val = torch.finfo(dtype).min
1248
+
1249
+ # Initialize mask with min_val (masked positions)
1250
+ mask = torch.full((num_new_tokens, total_len), min_val, device=device, dtype=dtype)
1251
+
1252
+ # Set visible positions to 0
1253
+ for i in range(num_new_tokens):
1254
+ if kv_cache_len > 0:
1255
+ mask[i, :kv_cache_len] = 0
1256
+ mask[i, kv_cache_len:kv_cache_len + i + 1] = 0
1257
+
1258
+ return mask.unsqueeze(0).unsqueeze(0)
1259
+
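
For example, with two cached tokens and three new tokens, `create_causal_mask_for_kv_cache` yields the same staircase of visible (zero) positions as the generic builder above, just computed with an explicit loop:

```python
import torch

kv_len, num_new = 2, 3
min_val = torch.finfo(torch.float32).min
mask = torch.full((num_new, kv_len + num_new), min_val)
for i in range(num_new):
    mask[i, :kv_len] = 0                  # everything already in the cache
    mask[i, kv_len:kv_len + i + 1] = 0    # earlier new tokens and itself
print((mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```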
1260
+ def create_4d_causal_mask(
1261
+ self,
1262
+ seq_len: int,
1263
+ device: torch.device,
1264
+ dtype: torch.dtype = torch.bfloat16
1265
+ ) -> torch.Tensor:
1266
+ """Create complete 4D causal attention mask for initial decoding.
1267
+
1268
+ Args:
1269
+ seq_len: Sequence length
1270
+ device: Target device for tensor allocation
1271
+ dtype: Data type for the mask tensor
1272
+
1273
+ Returns:
1274
+ Causal attention mask of shape (1, 1, seq_len, seq_len)
1275
+ """
1276
+ min_val = torch.finfo(dtype).min
1277
+
1278
+ # Create lower triangular causal mask
1279
+ mask = torch.full((seq_len, seq_len), min_val, device=device, dtype=dtype)
1280
+ mask = torch.triu(mask, diagonal=1)
1281
+
1282
+ return mask.unsqueeze(0).unsqueeze(0)
1283
+
1284
+ def _first_decoder(
1285
+ self,
1286
+ new_input_ids: torch.Tensor,
1287
+ past_key_values: Optional[DynamicCache] = None,
1288
+ image_embeds: Optional[torch.Tensor] = None,
1289
+ image_mask: Optional[torch.Tensor] = None,
1290
+ num_gen: int = 32
1291
+ ) -> Tuple[torch.Tensor, Any]:
1292
+ """Execute decoder pass with causal attention masking.
1293
+
1294
+ This method performs a single decoder pass with optimized attention masking.
1295
+ On the first decoding step (when past_key_values is None), it processes image
1296
+ embeddings and merges them with text embeddings.
1297
+
1298
+ Args:
1299
+ new_input_ids: Input token IDs of shape (batch_size, seq_len)
1300
+ past_key_values: Cached key-value pairs from previous decoding steps
1301
+ image_embeds: Image embeddings to merge (only used in first step)
1302
+ image_mask: Mask indicating positions for image embedding placement
1303
+ num_gen: Number of tokens to generate in parallel
1304
+
1305
+ Returns:
1306
+ Tuple containing:
1307
+ - predicted_token_ids: Predicted token IDs of shape (batch_size, num_gen)
1308
+ - outputs: Model outputs including logits and updated cache
1309
+ """
1310
+ # Get current sequence position
1311
+ start_position = past_key_values.get_seq_length() if past_key_values is not None else 0
1312
+ batch_size, seq_len = new_input_ids.shape
1313
+
1314
+ # Create position IDs directly on GPU to avoid CPU-GPU transfer
1315
+ position_ids = torch.arange(
1316
+ start_position,
1317
+ start_position + seq_len,
1318
+ dtype=torch.long,
1319
+ device=new_input_ids.device
1320
+ ).unsqueeze(0)
1321
+
1322
+ # Process image embeddings only on first decoding step
1323
+ inputs_embeds = None
1324
+ if start_position == 0:
1325
+ inputs_embeds = self.get_input_idx_embeddings(new_input_ids)
1326
+ if image_embeds is not None:
1327
+ inputs_embeds = self.get_merge_embedding(inputs_embeds, image_embeds, image_mask)
1328
+
1329
+ # Create 4D causal attention mask
1330
+ attention_mask = None
1331
+ if start_position > 0 and seq_len > 0:
1332
+ # When using KV cache, create mask for new tokens
1333
+ attention_mask = self.create_causal_mask_for_kv_cache(
1334
+ start_position, seq_len, new_input_ids.device, dtype=torch.bfloat16
1335
+ )
1336
+ elif start_position == 0 and seq_len > 0:
1337
+ # First decoding, create complete causal mask
1338
+ attention_mask = self.create_4d_causal_mask(
1339
+ seq_len, new_input_ids.device, dtype=torch.bfloat16
1340
+ )
1341
+
1342
+ with torch.no_grad():
1343
+ if start_position > 0:
1344
+ outputs = self.forward(
1345
+ input_ids=new_input_ids,
1346
+ inputs_embeds=None,
1347
+ attention_mask=None, # Note: attention_mask currently disabled
1348
+ position_ids=position_ids,
1349
+ use_cache=True,
1350
+ cache_position=True,
1351
+ past_key_values=past_key_values,
1352
+ )
1353
+ else:
1354
+ outputs = self.forward(
1355
+ input_ids=None,
1356
+ inputs_embeds=inputs_embeds,
1357
+ attention_mask=None, # Note: attention_mask currently disabled
1358
+ position_ids=position_ids,
1359
+ use_cache=True,
1360
+ cache_position=True,
1361
+ past_key_values=past_key_values,
1362
+ )
1363
+
1364
+ # Extract predicted token IDs from logits
1365
+ predicted_token_ids = outputs.logits[:, -(num_gen + 1):-1].argmax(dim=-1)
1366
+
1367
+ return predicted_token_ids, outputs
1368
+
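
`_first_decoder` appends `num_gen` placeholder tokens and reads their predictions one step earlier in the logits: position `-(num_gen + 1)` predicts the first placeholder, and the final position is dropped because it would predict the token after the block. A tiny index check:

```python
import torch

seq_len, num_gen, vocab = 10, 4, 50
logits = torch.randn(1, seq_len, vocab)

# Logits at position t predict token t + 1, so this slice yields exactly
# one prediction per appended placeholder token.
pred = logits[:, -(num_gen + 1):-1].argmax(dim=-1)
print(pred.shape)  # torch.Size([1, 4])
```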
1369
+ def generate_parallel_decoder(
1370
+ self,
1371
+ inputs: Dict[str, torch.Tensor],
1372
+ image_embeds: torch.Tensor,
1373
+ mask_token_id: int,
1374
+ max_new_tokens: int = 8192,
1375
+ num_gen: int = 64,
1376
+ verbose: bool = False
1377
+ ) -> List[int]:
1378
+ """Generate tokens using optimized parallel decoding with dual-pass verification.
1379
+
1380
+ This method implements a parallel decoding strategy that generates multiple tokens
1381
+ simultaneously and verifies them in a second pass. The algorithm:
1382
+ 1. First pass: Predict tokens with mask tokens
1383
+ 2. Second pass: Verify predictions with actual predicted tokens
1384
+ 3. Accept verified tokens and continue from the first unverified position
1385
+
1386
+ Optimizations:
1387
+ - First decoding uses cloned cache to avoid modifying the original
1388
+ - Second decoding updates the original cache in-place
1389
+ - Minimizes CPU-GPU data transfers by operating on GPU
1390
+ - Pre-allocates tensors to avoid repeated creation
1391
+ - Removes debug output from inner loops (controlled by verbose flag)
1392
+ - Entire loop wrapped with torch.no_grad() for efficiency
1393
+
1394
+ Args:
1395
+ inputs: Input dictionary containing 'input_ids' tensor
1396
+ image_embeds: Image embeddings for multimodal processing
1397
+ mask_token_id: Token ID used for masked positions
1398
+ max_new_tokens: Maximum number of tokens to generate
1399
+ num_gen: Number of tokens to generate in parallel per iteration
1400
+ verbose: If True, print detailed progress information
1401
+
1402
+ Returns:
1403
+ List of generated token IDs
1404
+ """
1405
+ if verbose:
1406
+ print("Starting parallel decoder generation")
1407
+
1408
+ # Constants
1409
+ STOP_TOKEN_ID = 128001
1410
+ device = self.model.device
1411
+ input_ids = inputs["input_ids"]
1412
+ decoder_idx = []
1413
+
1414
+ # Pre-allocate mask tokens tensor to avoid repeated creation
1415
+ mask_tokens = torch.full((1, num_gen), mask_token_id, dtype=torch.long, device=device)
1416
+
1417
+ # Initialize KV cache
1418
+ prefix_past_key_values = DynamicCache()
1419
+ step = 0
1420
+ is_exit = False
1421
+
1422
+ # Cache initial token ID
1423
+ prefix_step_id = input_ids[0, 0].item()
1424
+
1425
+ with torch.no_grad():
1426
+ while len(decoder_idx) < max_new_tokens and not is_exit:
1427
+ # ============ First Pass: Predict with mask tokens ============
1428
+ new_input_ids = torch.cat([input_ids, mask_tokens], dim=1)
1429
+
1430
+ # Use cloned cache for first pass to preserve original
1431
+ if step == 0:
1432
+ first_cache = DynamicCache()
1433
+
1434
+ # Create image mask for first step
1435
+ mask = new_input_ids == self.config.image_token_id
1436
+ mask_unsqueezed = mask.unsqueeze(-1)
1437
+ mask_expanded = mask_unsqueezed.expand(-1, -1, image_embeds.size(-1))
1438
+ image_mask = mask_expanded.to(image_embeds.device)
1439
+ else:
1440
+ first_cache = self.clone_past_key_values(prefix_past_key_values)
1441
+
1442
+ first_predicted_ids, _ = self._first_decoder(
1443
+ new_input_ids,
1444
+ past_key_values=first_cache,
1445
+ image_embeds=image_embeds if step == 0 else None,
1446
+ image_mask=image_mask if step == 0 else None,
1447
+ num_gen=num_gen
1448
+ )
1449
+
1450
+ # ============ Second Pass: Verify with predicted tokens ============
1451
+ new_input_ids = torch.cat([input_ids, first_predicted_ids], dim=1)
1452
+
1453
+ # Use original cache for second pass (will be updated and retained)
1454
+ if step == 0:
1455
+ second_cache = DynamicCache()
1456
+ else:
1457
+ second_cache = prefix_past_key_values
1458
+
1459
+ second_predicted_ids, outputs = self._first_decoder(
1460
+ new_input_ids,
1461
+ past_key_values=second_cache,
1462
+ image_embeds=image_embeds if step == 0 else None,
1463
+ image_mask=image_mask if step == 0 else None,
1464
+ num_gen=num_gen
1465
+ )
1466
+
1467
+ # ============ Compare predictions and count successes ============
1468
+ first_pred_list = first_predicted_ids[0].tolist()
1469
+ second_pred_list = second_predicted_ids[0].tolist()
1470
+
1471
+ if verbose:
1472
+ print(f"First pass predictions: {first_pred_list}")
1473
+ print(f"Second pass predictions: {second_pred_list}")
1474
+
1475
+ # Compare predictions to find verified tokens
1476
+ success = 0
1477
+ for idx in range(len(second_pred_list) - 1):
1478
+ first_id = first_pred_list[idx]
1479
+ second_id = second_pred_list[idx]
1480
+ next_second_id = second_pred_list[idx + 1]
1481
+
1482
+ # Check for stop token
1483
+ if second_id == STOP_TOKEN_ID:
1484
+ is_exit = True
1485
+ break
1486
+
1487
+ if next_second_id == STOP_TOKEN_ID and idx == len(second_pred_list) - 2:
1488
+ success += 1
1489
+ is_exit = True
1490
+ break
1491
+
1492
+ # Verify prediction consistency
1493
+ if first_id == second_id:
1494
+ success += 1
1495
+ else:
1496
+ break
1497
+
1498
+ # ============ Update decoded tokens ============
1499
+ if step == 0:
1500
+ decoder_idx.extend(second_pred_list[:success])
1501
+ else:
1502
+ if verbose:
1503
+ print(f"Verified {success} tokens: {second_pred_list[:success]}")
1504
+ decoder_idx.append(prefix_step_id)
1505
+ decoder_idx.extend(second_pred_list[:success])
1506
+
1507
+ if verbose:
1508
+ print(f"Exit status: {is_exit}")
1509
+ print(f"Total decoded tokens: {len(decoder_idx)}")
1510
+
1511
+ # ============ Truncate KV cache to verified length ============
1512
+ past_key_values = outputs.past_key_values
1513
+ if past_key_values is not None:
1514
+ current_kv_len = past_key_values.get_seq_length()
1515
+ num_to_keep = current_kv_len - (num_gen - success)
1516
+ prefix_past_key_values = self.truncate_past_key_values(
1517
+ past_key_values, num_to_keep
1518
+ )
1519
+ else:
1520
+ prefix_past_key_values = None
1521
+
1522
+ # Update input_ids for next iteration
1523
+ next_token_id = (
1524
+ second_pred_list[success]
1525
+ if success < len(second_pred_list)
1526
+ else prefix_step_id
1527
+ )
1528
+ input_ids = torch.tensor(
1529
+ [[next_token_id]],
1530
+ dtype=torch.long,
1531
+ device=device
1532
+ )
1533
+ prefix_step_id = next_token_id
1534
+
1535
+ step += 1
1536
+
1537
+ if verbose:
1538
+ print(f"Step {step} completed, success rate: {success}/{num_gen}\n")
1539
+
1540
+ return decoder_idx
1541
+
1542
+ __all__ = ["YoutuPreTrainedModel", "YoutuModel", "YoutuVLForConditionalGeneration"]
preprocessor_config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_youtu_vl.YoutuVLProcessor",
4
+ "AutoImageProcessor": "image_processing_siglip2_fast.Siglip2ImageProcessorFast"
5
+ },
6
+ "processor_class": "YoutuVLProcessor",
7
+ "do_convert_rgb": null,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.5,
13
+ 0.5,
14
+ 0.5
15
+ ],
16
+ "image_processor_type": "Siglip2ImageProcessorFast",
17
+ "image_std": [
18
+ 0.5,
19
+ 0.5,
20
+ 0.5
21
+ ],
22
+ "max_num_patches": 256,
23
+ "patch_size": 16,
24
+ "resample": 2,
25
+ "rescale_factor": 0.00392156862745098
26
+ }
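For reference, the preprocessing this config describes is the usual rescale-then-normalize pipeline: 8-bit pixel values are multiplied by rescale_factor (1/255) and then shifted and scaled by image_mean and image_std. A small sketch of that arithmetic using the exact values above (illustrative only, not the actual Siglip2ImageProcessorFast code):

# Per-channel pixel math implied by preprocessor_config.json (illustrative only).
import numpy as np

rescale_factor = 0.00392156862745098   # 1 / 255, from the config
mean, std = 0.5, 0.5                   # image_mean / image_std, from the config

raw = np.array([0, 128, 255], dtype=np.float32)     # example 8-bit values
normalized = (raw * rescale_factor - mean) / std    # do_rescale, then do_normalize
print(normalized)   # roughly [-1.0, 0.0039, 1.0] -> model inputs lie in [-1, 1]

With patch_size 16, the image processor is expected to resize inputs so that the resulting 16x16 patch grid stays within the configured patch budget before the patches are flattened for the vision tower.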
processing_youtu_vl.py ADDED
@@ -0,0 +1,173 @@
1
+ from typing import List, Union
2
+ import numpy
3
+ from transformers.feature_extraction_utils import BatchFeature
4
+ from transformers.image_utils import ImageInput
5
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
6
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
7
+
8
+ class YoutuVLVideosProcessorKwargs(VideosKwargs, total=False):
9
+ fps: Union[List[float], float]
10
+
11
+
12
+ class YoutuVLProcessorKwargs(ProcessingKwargs, total=False):
13
+ videos_kwargs: YoutuVLVideosProcessorKwargs
14
+ _defaults = {
15
+ "text_kwargs": {
16
+ "padding": False,
17
+ },
18
+ "videos_kwargs": {"fps": 2.0},
19
+ }
20
+
21
+
22
+ class YoutuVLProcessor(ProcessorMixin):
23
+
24
+ attributes = ["image_processor", "tokenizer"]
25
+ valid_kwargs = ["chat_template"]
26
+
27
+ image_processor_class = "AutoImageProcessor"
28
+ tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast")
29
+
30
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
31
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
32
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
33
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
34
+
35
+ def __call__(
36
+ self,
37
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
38
+ images: ImageInput = None,
39
+ max_image_patches: int = 36864,
40
+ **kwargs: Unpack[YoutuVLProcessorKwargs],
41
+ ) -> BatchFeature:
42
+ """
43
+ Args:
44
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`,
45
+ `List[np.ndarray]`, `List[torch.Tensor]`):
46
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
47
+ tensor. Both channels-first and channels-last formats are supported.
48
+ text (`str`, `List[str]`, `List[List[str]]`):
49
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
50
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
51
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
52
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
53
+ If set, will return tensors of a particular framework. Acceptable values are:
54
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
55
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
56
+ - `'np'`: Return NumPy `np.ndarray` objects.
57
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
58
+
59
+ Returns:
60
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
61
+
62
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
63
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
64
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
65
+ `None`).
66
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
67
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model.
68
+ Returned when `videos` is not `None`.
69
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
70
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
71
+ - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
72
+ """
73
+ output_kwargs = self._merge_kwargs(
74
+ YoutuVLProcessorKwargs,
75
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
76
+ **kwargs,
77
+ )
78
+ if images is not None:
79
+ image_inputs = self.image_processor(images=images, max_num_patches=max_image_patches, return_tensors="pt")
80
+ else:
81
+ image_inputs = {}
82
+ image_grid_thw = None
83
+
84
+ videos_inputs = {}
85
+ video_grid_thw = None
86
+
87
+ if not isinstance(text, list):
88
+ text = [text]
89
+
90
+ image_tokens = []
91
+ if images is not None:
92
+ merge_length = 4
93
+ index = 0
94
+ for i in range(len(text)):
95
+ while self.image_token in text[i]:
96
+ h = image_inputs['spatial_shapes'][index][0]
97
+ w = image_inputs['spatial_shapes'][index][1]
98
+ repeats = h * w // merge_length
99
+ text[i] = text[i].replace(
100
+ self.image_token,
101
+ "<|placeholder|>" * repeats,
102
+ 1,
103
+ )
104
+ index += 1
105
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
106
+ assert index == image_inputs['spatial_shapes'].shape[0]
107
+
108
+
109
+ if video_grid_thw is not None:
110
+ merge_length = self.image_processor.merge_size ** 2
111
+ index = 0
112
+ for i in range(len(text)):
113
+ while self.video_token in text[i]:
114
+ text[i] = text[i].replace(
115
+ self.video_token,
116
+ "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
117
+ 1,
118
+ )
119
+ index += 1
120
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
121
+
122
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
123
+
124
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
125
+
126
+ def get_max_image_patches(self, images):
127
+ return self.image_processor.get_max_image_patches(images)
128
+
129
+ def batch_decode(self, *args, **kwargs):
130
+ return self.tokenizer.batch_decode(*args, **kwargs)
131
+
132
+ def decode(self, *args, **kwargs):
133
+ return self.tokenizer.decode(*args, **kwargs)
134
+
135
+ def post_process_image_text_to_text(
136
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
137
+ ):
138
+ """
139
+ Post-process the output of the model to decode the text.
140
+
141
+ Args:
142
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
143
+ The output of the model `generate` function. The output is
144
+ expected to be a tensor of shape `(batch_size, sequence_length)`
145
+ or `(sequence_length,)`.
146
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
147
+ Whether or not to remove special tokens in the output. Argument
148
+ passed to the tokenizer's `batch_decode` method.
149
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
150
+ Whether or not to clean up the tokenization spaces. Argument
151
+ passed to the tokenizer's `batch_decode` method.
152
+ **kwargs:
153
+ Additional arguments to be passed to the tokenizer's `batch_decode` method.
154
+
155
+ Returns:
156
+ `List[str]`: The decoded text.
157
+ """
158
+ return self.tokenizer.batch_decode(
159
+ generated_outputs,
160
+ skip_special_tokens=skip_special_tokens,
161
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
162
+ **kwargs,
163
+ )
164
+
165
+ @property
166
+ def model_input_names(self):
167
+ tokenizer_input_names = self.tokenizer.model_input_names
168
+ image_processor_input_names = self.image_processor.model_input_names
169
+ names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
170
+ return names_from_processor + ["second_per_grid_ts"]
171
+
172
+
173
+ __all__ = ["YoutuVLProcessor"]
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|end_of_text|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|end_of_text|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ea752449128cbb859dc8799d1d40c0a7dc7eac50ed9ccf47cea7a7100dcabe5
3
+ size 19953833
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token": "<|begin_of_text|>",
3
+ "clean_up_tokenization_spaces": false,
4
+ "eos_token": "<|end_of_text|>",
5
+ "extra_special_tokens": {},
6
+ "model_input_names": [
7
+ "input_ids",
8
+ "attention_mask"
9
+ ],
10
+ "model_max_length": 131072,
11
+ "pad_token": "<|end_of_text|>",
12
+ "tokenizer_class": "PreTrainedTokenizerFast",
13
+ "truncation_side": "left",
14
+ "use_fast": true
15
+ }
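Because the pad token is mapped to `<|end_of_text|>` (see special_tokens_map.json and the config above), padding and end-of-sequence share one id, and `model_max_length` is 131072. A quick sanity check, assuming a local clone of this repository:

# Sanity check of the tokenizer settings above; "." is assumed to be a local clone.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")
print(tok.bos_token, tok.eos_token, tok.pad_token)  # <|begin_of_text|> <|end_of_text|> <|end_of_text|>
print(tok.model_max_length)                         # 131072
assert tok.pad_token_id == tok.eos_token_id         # pad and eos resolve to the same id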