ValentineKRAFTON committed
Commit acd771b · verified · 1 Parent(s): 61289b0

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/photo.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
LICENSE ADDED
@@ -0,0 +1,191 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to the Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by the Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding any notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   Copyright 2024-2026 Raon Vision Team
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
LICENSES/MIT-OpenAI-CLIP.txt ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
LICENSES/MIT-OpenCLIP.txt ADDED
@@ -0,0 +1,25 @@
+MIT License
+
+Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
+Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
+John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
+Ludwig Schmidt
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
NOTICE ADDED
@@ -0,0 +1,60 @@
+raon-vision-encoder
+Copyright 2024-2026 Raon Vision Team
+
+This product includes software derived from the following projects:
+
+===============================================================================
+OpenCLIP
+https://github.com/mlfoundations/open_clip
+Licensed under the MIT License (see LICENSES/MIT-OpenCLIP.txt)
+
+Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
+Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
+John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
+Ludwig Schmidt
+
+Used in: model/ and train/ packages (LocCa, CLIP, loss, factory,
+transformer, data pipeline, training loop, etc.)
+
+===============================================================================
+OpenAI CLIP
+https://github.com/openai/CLIP
+Licensed under the MIT License (see LICENSES/MIT-OpenAI-CLIP.txt)
+
+Copyright (c) 2021 OpenAI
+
+Used in: model/tokenizer.py, model/bpe_simple_vocab_16e6.txt.gz
+
+===============================================================================
+Meta Platforms, Inc. (MAE / MoCo v3)
+Licensed under the MIT License via OpenCLIP
+
+Copyright (c) Meta Platforms, Inc. and affiliates
+
+Used in: model/pos_embed.py (sincos position embedding utilities)
+
+===============================================================================
+timm (pytorch-image-models)
+https://github.com/huggingface/pytorch-image-models
+Licensed under the Apache License 2.0
+
+Copyright (c) Ross Wightman
+
+Used in: model/transform.py (ResizeKeepRatio)
+
+===============================================================================
+References
+
+The following papers informed the design and implementation of features
+in this software. Code was independently implemented unless noted above.
+
+- CoCa: Yu et al., "CoCa: Contrastive Captioners are Image-Text Foundation Models", 2022
+- SigLIP: Zhai et al., "Sigmoid Loss for Language Image Pre-Training", 2023
+- SigLIP2: Tschannen et al., "SigLIP 2: Multilingual Vision-Language Encoders", 2025
+- DINO: Caron et al., "Emerging Properties in Self-Supervised Vision Transformers", 2021
+- DINOv2: Oquab et al., "DINOv2: Learning Robust Visual Features without Supervision", 2024
+- SILC: Naeem et al., "SILC: Improving Vision Language Pretraining with Self-Distillation", 2023
+- TIPS: Huang et al., "TIPS: Text-Image Pretraining with Spatial Awareness", 2024
+- Koleo: Sablayrolles et al., "Spreading vectors for similarity search", ICLR 2019
+- Gram Anchoring: Simeoni et al., "DINOv3", 2025 (independently implemented)
+- NaFlex: from SigLIP2 / PaLI (independently implemented in PyTorch)
README.md ADDED
@@ -0,0 +1,83 @@
+---
+library_name: transformers
+tags:
+- vision
+- image-text
+- clip
+- zero-shot
+---
+
+<div align="center">
+<img class="block dark:hidden" src="assets/Raon-VisionEncoder-Gradient-Black.png" alt="Raon VisionEncoder" width="600">
+<img class="hidden dark:block" src="assets/Raon-VisionEncoder-Gradient-White.png" alt="Raon VisionEncoder" width="600">
+</div>
+
+<p align="center">
+<a href="https://www.krafton.ai/ko/"><img src="https://img.shields.io/badge/Homepage-KRAFTON%20AI-blue?style=flat&logo=google-chrome&logoColor=white" alt="Homepage"></a>
+<br>
+<a href="https://huggingface.co/KRAFTON"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-KRAFTON-yellow?style=flat" alt="Hugging Face"></a>
+<a href="https://x.com/Krafton_AI"><img src="https://img.shields.io/badge/X-KRAFTON%20AI-white?style=flat&logo=x&logoColor=black" alt="X"></a>
+<br>
+<a href="https://www.apache.org/licenses/LICENSE-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-lightgrey?style=flat" alt="License"></a>
+</p>
+
+**Raon-VisionEncoder** is a 1.14B-parameter vision-language foundation model by [KRAFTON](https://www.krafton.com) for image and text feature extraction.
+It supports zero-shot image classification, image-text retrieval, and native aspect ratio inference via NaFlex.
+It is built on [OpenCLIP](https://github.com/mlfoundations/open_clip) with a LocCa (Localized CoCa) architecture and a ViT-SO400M vision encoder.
+
+## Pretrained Models
+
+| Model | Params (Inference) | Vision | Text | Patch Size | NaFlex Default Patches |
+|-------|--------------------|--------|------|------------|------------------------|
+| LocCa ViT-SO400M-16-SigLIP2 | 1.14B | 0.43B | 0.71B | 16x16 | 256 |
+
+## Requirements
+
+```bash
+pip install torch torchvision timm transformers huggingface-hub safetensors ftfy
+```
+
+## Quick Start
+
+```python
+import torch
+from transformers import AutoModel
+from PIL import Image
+
+# Load model + processor
+model = AutoModel.from_pretrained("KRAFTON/Raon-VisionEncoder", trust_remote_code=True)
+model = model.to(dtype=torch.bfloat16).eval()
+processor = model.get_processor("KRAFTON/Raon-VisionEncoder")
+
+# Encode image and text
+img_inputs = processor(images=Image.open("assets/photo.jpg"))
+txt_inputs = processor(text=["a cat", "a dog"])
+
+with torch.no_grad():
+    img_feat = model.encode_image(**img_inputs)
+    txt_feat = model.encode_text(**txt_inputs)
+
+# Compute similarity with learned scale and bias
+logits = model.logit_scale.exp() * (img_feat @ txt_feat.T) + model.logit_bias
+probs = logits.softmax(dim=-1)
+print(probs)
+```
+
+## API Reference
+
+| Method | Input | Output |
+|--------|-------|--------|
+| `model.encode_image(**inputs)` | Processor output (image) | `[B, 1152]` normalized image features |
+| `model.encode_text(**inputs)` | Processor output (text) | `[B, 1152]` normalized text features |
+| `model.logit_scale` | - | Learned temperature parameter |
+| `model.logit_bias` | - | Learned bias parameter |
+| `model.get_processor(repo_id)` | HuggingFace repo ID | Processor instance |
+| `processor(images=img)` | PIL Image | Preprocessed image dict |
+| `processor(text=["a cat"])` | list of strings | Tokenized text dict |
+
+## License
+
+This repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+Third-party notices are listed in [NOTICE](NOTICE).
+
+© 2026 KRAFTON
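The Quick Start's final step (scale the cosine similarities, add the bias, take a softmax over captions) can be traced numerically without loading the model. This is a plain-Python sketch; the scale, bias, and cosine values are made-up placeholders, not outputs of the trained checkpoint:

```python
import math

# One image scored against two captions. Illustrative numbers only.
logit_scale_exp = 100.0          # stands in for model.logit_scale.exp()
logit_bias = -10.0               # stands in for model.logit_bias
cosine = [0.31, 0.12]            # image-text cosine similarities

logits = [logit_scale_exp * c + logit_bias for c in cosine]

# Numerically stable softmax over the caption axis
m = max(logits)
exps = [math.exp(l - m) for l in logits]
probs = [e / sum(exps) for e in exps]
print(probs)  # heavily favors the first caption
```

Because the learned temperature is large, even modest cosine gaps turn into near-one-hot probabilities after the softmax.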
assets/Raon-VisionEncoder-Gradient-Black.png ADDED
assets/Raon-VisionEncoder-Gradient-White.png ADDED
assets/photo.jpg ADDED

Git LFS Details

  • SHA256: 63399d70c550f7e0fb3738d954f04918f0f4a2532e43cc6a1b90ac82d104ed16
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB
config.json ADDED
@@ -0,0 +1,40 @@
+{
+  "architectures": [
+    "RaonVEModel"
+  ],
+  "model_type": "raon_ve",
+  "auto_map": {
+    "AutoConfig": "configuration_raonve.RaonVEConfig",
+    "AutoModel": "modeling_raonve.RaonVEModel"
+  },
+  "embed_dim": 1152,
+  "init_logit_bias": -10,
+  "vision_config": {
+    "image_size": 256,
+    "timm_model_name": "vit_so400m_patch16_siglip_256",
+    "timm_model_pretrained": false,
+    "timm_pool": "map",
+    "timm_proj": "none"
+  },
+  "text_config": {
+    "context_length": 64,
+    "vocab_size": 256000,
+    "hf_tokenizer_name": "timm/ViT-SO400M-16-SigLIP2-256",
+    "tokenizer_kwargs": {
+      "clean": "canonicalize"
+    },
+    "width": 1152,
+    "heads": 16,
+    "layers": 27,
+    "mlp_ratio": 3.7362,
+    "no_causal_mask": true,
+    "proj_bias": true,
+    "pool_type": "last",
+    "norm_kwargs": {
+      "eps": 1e-06
+    },
+    "act_kwargs": {
+      "approximate": "tanh"
+    }
+  }
+}
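The `"init_logit_bias": -10` entry in config.json is worth a quick numeric check. Assuming a SigLIP-style sigmoid objective (as the SigLIP references in NOTICE suggest), a pair's match probability is `sigmoid(scale * cosine + bias)`, so a strongly negative initial bias makes every image-text pair start out "unmatched". The sketch below uses an arbitrary stand-in scale, not the trained value:

```python
import math

def pair_prob(cosine, scale_exp=10.0, bias=-10.0):
    # sigmoid(scale * cosine + bias): SigLIP-style pairwise match probability.
    # scale_exp is a hypothetical stand-in for exp(logit_scale).
    return 1.0 / (1.0 + math.exp(-(scale_exp * cosine + bias)))

print(pair_prob(0.0))  # sigmoid(-10): roughly 4.5e-05, near zero
print(pair_prob(1.0))  # sigmoid(0): exactly 0.5
```

At initialization, random features have near-zero cosine similarity, so the bias of -10 pins almost all pairwise probabilities near zero, matching the fact that most pairs in a batch are negatives.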
configuration_raonve.py ADDED
@@ -0,0 +1,96 @@
+"""Raon-VisionEncoder configuration."""
+
+from transformers import PretrainedConfig
+
+
+class RaonVEVisionConfig(PretrainedConfig):
+    model_type = "raon_ve_vision"
+
+    def __init__(
+        self,
+        image_size=256,
+        timm_model_name="vit_so400m_patch16_siglip_256",
+        timm_model_pretrained=False,
+        timm_pool="map",
+        timm_proj="none",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_size = image_size
+        self.timm_model_name = timm_model_name
+        self.timm_model_pretrained = timm_model_pretrained
+        self.timm_pool = timm_pool
+        self.timm_proj = timm_proj
+
+
+class RaonVETextConfig(PretrainedConfig):
+    model_type = "raon_ve_text"
+
+    def __init__(
+        self,
+        context_length=64,
+        vocab_size=256000,
+        width=1152,
+        heads=16,
+        layers=27,
+        mlp_ratio=3.7362,
+        no_causal_mask=True,
+        proj_bias=True,
+        pool_type="last",
+        hf_tokenizer_name="timm/ViT-SO400M-16-SigLIP2-256",
+        tokenizer_kwargs=None,
+        norm_kwargs=None,
+        act_kwargs=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.context_length = context_length
+        self.vocab_size = vocab_size
+        self.width = width
+        self.heads = heads
+        self.layers = layers
+        self.mlp_ratio = mlp_ratio
+        self.no_causal_mask = no_causal_mask
+        self.proj_bias = proj_bias
+        self.pool_type = pool_type
+        self.hf_tokenizer_name = hf_tokenizer_name
+        self.tokenizer_kwargs = tokenizer_kwargs or {"clean": "canonicalize"}
+        self.norm_kwargs = norm_kwargs or {"eps": 1e-6}
+        self.act_kwargs = act_kwargs or {"approximate": "tanh"}
+
+
+class RaonVEConfig(PretrainedConfig):
+    model_type = "raon_ve"
+    is_composition = True
+
+    def __init__(
+        self,
+        embed_dim=1152,
+        init_logit_bias=-10,
+        vision_config=None,
+        text_config=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.init_logit_bias = init_logit_bias
+
+        if isinstance(vision_config, dict):
+            self.vision_config = RaonVEVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = RaonVEVisionConfig()
+        else:
+            self.vision_config = vision_config
+
+        if isinstance(text_config, dict):
+            self.text_config = RaonVETextConfig(**text_config)
+        elif text_config is None:
+            self.text_config = RaonVETextConfig()
+        else:
+            self.text_config = text_config
+
+    def to_dict(self):
+        output = super().to_dict()
+        output["vision_config"] = self.vision_config.to_dict()
+        output["text_config"] = self.text_config.to_dict()
+        return output
modeling_raonve.py ADDED
@@ -0,0 +1,235 @@
+"""Raon-VisionEncoder model."""
+
+import importlib
+import os
+import sys
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PreTrainedModel
+
+from .configuration_raonve import RaonVEConfig
+
+
+_raon_repo_id = None
+
+def set_repo_id(repo_id):
+    global _raon_repo_id
+    _raon_repo_id = repo_id
+
+def _ensure_raon_package():
+    """Import raon_vision_encoder, downloading from HF Hub if needed."""
+    try:
+        clip_mod = importlib.import_module("raon_vision_encoder.clip")
+        return clip_mod.CustomTextCLIP
+    except (ImportError, ModuleNotFoundError):
+        pass
+
+    from huggingface_hub import snapshot_download
+    repo_id = _raon_repo_id or "KRAFTON/Raon-VisionEncoder"
+    repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
+    sys.path.insert(0, repo_dir)
+
+    for key in list(sys.modules.keys()):
+        if key.startswith("raon_vision_encoder"):
+            del sys.modules[key]
+
+    clip_mod = importlib.import_module("raon_vision_encoder.clip")
+    return clip_mod.CustomTextCLIP
+
+
+class RaonVEPreTrainedModel(PreTrainedModel):
+    config_class = RaonVEConfig
+    base_model_prefix = ""
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        pass
+
+
+class RaonVEModel(RaonVEPreTrainedModel):
+    config_class = RaonVEConfig
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        set_repo_id(str(pretrained_model_name_or_path))
+        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+    def __init__(self, config: RaonVEConfig):
+        super().__init__(config)
+
+        vision_cfg = {
+            "image_size": config.vision_config.image_size,
+            "timm_model_name": config.vision_config.timm_model_name,
+            "timm_model_pretrained": config.vision_config.timm_model_pretrained,
+            "timm_pool": config.vision_config.timm_pool,
+            "timm_proj": config.vision_config.timm_proj,
+        }
+        text_cfg = {
+            "context_length": config.text_config.context_length,
+            "vocab_size": config.text_config.vocab_size,
+            "width": config.text_config.width,
+            "heads": config.text_config.heads,
+            "layers": config.text_config.layers,
+            "mlp_ratio": config.text_config.mlp_ratio,
+            "no_causal_mask": config.text_config.no_causal_mask,
+            "proj_bias": config.text_config.proj_bias,
+            "pool_type": config.text_config.pool_type,
+            "hf_tokenizer_name": config.text_config.hf_tokenizer_name,
+            "tokenizer_kwargs": config.text_config.tokenizer_kwargs,
+            "norm_kwargs": config.text_config.norm_kwargs,
+            "act_kwargs": config.text_config.act_kwargs,
+        }
+
+        CustomTextCLIP = _ensure_raon_package()
+        inner = CustomTextCLIP(
+            embed_dim=config.embed_dim,
+            vision_cfg=vision_cfg,
+            text_cfg=text_cfg,
+            init_logit_bias=config.init_logit_bias,
+        )
+
+        self.visual = inner.visual
+        self.text = inner.text
+        self.logit_scale = inner.logit_scale
+        self.logit_bias = inner.logit_bias
+
+        # Enable NaFlex by default
+        self.visual._setup_1d_forward()
+
+        self.post_init()
+
+    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
+        """Encode images to normalized feature vectors [B, 1152].
+        Pass the output of processor(images=...) directly via **inputs.
+        """
+        kwargs = {}
+        if pixel_attention_mask is not None:
+            kwargs["patch_valid_mask"] = pixel_attention_mask
+        if spatial_shapes is not None:
+            kwargs["spatial_shapes"] = spatial_shapes
+        features = self.visual(pixel_values, **kwargs) if kwargs else self.visual(pixel_values)
+        return F.normalize(features, dim=-1)
+
+    def encode_text(self, input_ids):
+        """Encode text to normalized feature vectors [B, 1152].
+        Pass the output of processor(text=...) directly via **inputs.
+        """
+        features = self.text(input_ids)
+        return F.normalize(features, dim=-1)
+
+    def forward(self, pixel_values=None, input_ids=None, pixel_attention_mask=None, spatial_shapes=None):
+        image_features = None
+        text_features = None
+
+        if pixel_values is not None:
+            image_features = self.encode_image(
+                pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                spatial_shapes=spatial_shapes,
+            )
+        if input_ids is not None:
+            text_features = self.encode_text(input_ids)
+
+        output = {
+            "image_features": image_features,
+            "text_features": text_features,
+            "logit_scale": self.logit_scale,
+            "logit_bias": self.logit_bias,
+        }
+        return output
+
+    @staticmethod
+    def get_processor(pretrained_model_name_or_path, **kwargs):
+        """Get the processor for this model."""
+        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+
+class RaonVEProcessor:
+    """Image and text processor for Raon-VisionEncoder.
+
+    Preprocesses images into NaFlex patch sequences and tokenizes text.
+
+    Args:
+        max_num_patches: Maximum number of patches per image (controls resolution).
+            Higher values preserve more detail. Default: 256.
+    """
+
+    DEFAULT_MAX_PATCHES = 256
+
+    def __init__(self, patch_size=16, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), tokenizer=None):
+        from torchvision import transforms as T
+        self.patch_size = patch_size
+        self.mean, self.std = mean, std
+        self.tokenizer = tokenizer
+        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        import json
+        from pathlib import Path as _Path
+        if _Path(pretrained_model_name_or_path).is_dir():
+            cfg_path = _Path(pretrained_model_name_or_path) / "config.json"
+        else:
+            from huggingface_hub import hf_hub_download
+            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
+        with open(cfg_path) as f:
+            cfg = json.load(f)
+        v = cfg.get("vision_config", {})
+        t = cfg.get("text_config", {})
+        ps = 16
+        for part in v.get("timm_model_name", "").split("_"):
+            if part.startswith("patch") and part[5:].isdigit():
+                ps = int(part[5:])
+                break
+        tokenizer = None
+        if t.get("hf_tokenizer_name"):
+            _ensure_raon_package()
+            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
+            tokenizer = tok_mod.HFTokenizer(
+                t["hf_tokenizer_name"], context_length=t.get("context_length", 64),
+                tokenizer_mode=t.get("tokenizer_mode"), **t.get("tokenizer_kwargs", {}),
+            )
+        return cls(patch_size=ps, tokenizer=tokenizer)
+
+    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
+        """Process images and/or text.
+
+        Args:
+            images: PIL Image or list of PIL Images.
+            text: String or list of strings.
+            max_num_patches: Resolution budget (default: 256). Higher = more detail.
+
+        Returns:
+            Dict with 'pixel_values', 'pixel_attention_mask', 'spatial_shapes' for images
+            and/or 'input_ids' for text.
+        """
+        from PIL import Image
+        result = {}
+        if images is not None:
+            mnp = max_num_patches or self.DEFAULT_MAX_PATCHES
+            _ensure_raon_package()
+            transform_mod = importlib.import_module("raon_vision_encoder.transform")
+            get_size = transform_mod.get_image_size_for_max_num_patches
+            imgs = [images] if isinstance(images, Image.Image) else images
+            ps = self.patch_size
+            all_p, all_m, all_s = [], [], []
+            for img in imgs:
+                img = img.convert("RGB")
+                w, h = img.size
+                th, tw = get_size(h, w, ps, mnp)
+ t = self._post(img.resize((tw, th), Image.BICUBIC))
220
+ gh, gw = th // ps, tw // ps
221
+ n = gh * gw
222
+ # [C, gh, ps, gw, ps] -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
223
+ patches = t.reshape(3, gh, ps, gw, ps).permute(1,3,0,2,4).reshape(n, 3*ps*ps)
224
+ padded = torch.zeros(mnp, ps*ps*3); padded[:n] = patches
225
+ mask = torch.zeros(mnp, dtype=torch.bool); mask[:n] = True
226
+ all_p.append(padded); all_m.append(mask)
227
+ all_s.append(torch.tensor([gh, gw]))
228
+ result["pixel_values"] = torch.stack(all_p)
229
+ result["pixel_attention_mask"] = torch.stack(all_m)
230
+ result["spatial_shapes"] = torch.stack(all_s)
231
+ if text is not None:
232
+ if self.tokenizer is None:
233
+ raise RuntimeError("Tokenizer not initialized.")
234
+ result["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
235
+ return result
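The patchification loop above flattens each resized image into a padded NaFlex patch sequence plus a validity mask. A minimal standalone sketch of that reshape/pad logic (plain torch, no PIL or resizing; `patchify` is a hypothetical name for illustration):

```python
import torch

def patchify(t, ps, max_patches):
    """Flatten a [3, H, W] tensor (H, W multiples of ps) into padded patch rows."""
    _, H, W = t.shape
    gh, gw = H // ps, W // ps
    n = gh * gw
    # [C, gh, ps, gw, ps] -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
    patches = t.reshape(3, gh, ps, gw, ps).permute(1, 3, 0, 2, 4).reshape(n, 3 * ps * ps)
    padded = torch.zeros(max_patches, 3 * ps * ps)
    padded[:n] = patches            # real patches first, zero padding after
    mask = torch.zeros(max_patches, dtype=torch.bool)
    mask[:n] = True                 # marks which rows are valid patches
    return padded, mask, (gh, gw)

x = torch.randn(3, 32, 48)          # a 2x3 grid of 16px patches
pv, m, shape = patchify(x, 16, 256)
print(pv.shape, int(m.sum()), shape)  # torch.Size([256, 768]) 6 (2, 3)
```

The `(gh, gw)` pair is what the processor stores in `spatial_shapes`, so the model can re-interpolate positional embeddings to the original grid.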
raon_vision_encoder/__init__.py ADDED
File without changes
raon_vision_encoder/clip.py ADDED
@@ -0,0 +1,287 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from functools import partial
+
+from .timm_model import TimmModel
+from .transformer import (
+    LayerNormFp32,
+    LayerNorm,
+    QuickGELU,
+    TextTransformer,
+    text_global_pool,
+)
+from .utils import to_2tuple
+
+
+@dataclass
+class CLIPVisionCfg:
+    layers: Union[Tuple[int, int, int, int], int] = 12
+    width: int = 768
+    head_width: int = 64
+    mlp_ratio: float = 4.0
+    patch_size: int = 16
+    image_size: Union[Tuple[int, int], int] = 224
+
+    ls_init_value: Optional[float] = None
+    patch_dropout: float = 0.0
+    attentional_pool: bool = False
+    attn_pooler_queries: int = 256
+    attn_pooler_heads: int = 8
+    no_ln_pre: bool = False
+    pos_embed_type: str = "learnable"
+    final_ln_after_pool: bool = False
+    pool_type: str = "tok"
+    output_tokens: bool = False
+    act_kwargs: Optional[dict] = None
+    norm_kwargs: Optional[dict] = None
+
+    block_type: Optional[str] = None
+    qk_norm: bool = False
+    scaled_cosine_attn: bool = False
+    scale_heads: bool = False
+    scale_attn_inner: bool = False
+    scale_attn: bool = False
+    scale_fc: bool = False
+
+    timm_model_name: Optional[str] = None
+    timm_model_pretrained: bool = False
+    timm_pool: str = "avg"
+    timm_proj: str = "linear"
+    timm_proj_bias: bool = False
+    timm_drop: float = 0.0
+    timm_drop_path: Optional[float] = None
+    timm_use_rope: bool = False
+    timm_rope_keep_ape: bool = False
+    timm_dynamic_img_size: bool = False
+    timm_norm_pre: bool = False
+
+
+@dataclass
+class CLIPTextCfg:
+    context_length: int = 77
+    vocab_size: int = 49408
+    hf_tokenizer_name: Optional[str] = None
+    tokenizer_mode: Optional[str] = None
+    tokenizer_kwargs: Optional[dict] = None
+
+    width: int = 512
+    heads: int = 8
+    layers: int = 12
+    mlp_ratio: float = 4.0
+    ls_init_value: Optional[float] = None
+    embed_cls: bool = False
+    pad_id: int = 0
+    eos_id: int = 2
+    no_causal_mask: bool = False
+    final_ln_after_pool: bool = False
+    pool_type: str = "argmax"
+    proj_bias: bool = False
+    proj_type: str = "linear"
+    output_tokens: bool = False
+    act_kwargs: Optional[dict] = None
+    norm_kwargs: Optional[dict] = None
+
+    block_type: Optional[str] = None
+    qk_norm: bool = False
+    scaled_cosine_attn: bool = False
+    scale_heads: bool = False
+    scale_attn_inner: bool = False
+    scale_attn: bool = False
+    scale_fc: bool = False
+
+    hf_model_name: Optional[str] = None
+    hf_model_pretrained: bool = True
+    hf_proj_type: str = "mlp"
+    hf_pooler_type: str = "mean_pooler"
+
+
+def get_cast_dtype(precision: str):
+    cast_dtype = None
+    if precision == "bf16":
+        cast_dtype = torch.bfloat16
+    elif precision == "fp16":
+        cast_dtype = torch.float16
+    return cast_dtype
+
+
+def _build_vision_tower(
+    embed_dim: int,
+    vision_cfg: CLIPVisionCfg,
+    quick_gelu: bool = False,
+    cast_dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(vision_cfg, dict):
+        vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+    if not vision_cfg.timm_model_name:
+        raise ValueError(
+            "Only TimmModel-based vision towers are supported in raon-vision-encoder. "
+            "Please set timm_model_name in vision_cfg."
+        )
+
+    visual = TimmModel(
+        vision_cfg.timm_model_name,
+        pretrained=vision_cfg.timm_model_pretrained,
+        pool=vision_cfg.timm_pool,
+        proj=vision_cfg.timm_proj,
+        proj_bias=vision_cfg.timm_proj_bias,
+        drop=vision_cfg.timm_drop,
+        drop_path=vision_cfg.timm_drop_path,
+        patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
+        init_values=vision_cfg.ls_init_value,
+        qk_norm=vision_cfg.qk_norm,
+        use_rope=vision_cfg.timm_use_rope,
+        rope_keep_ape=vision_cfg.timm_rope_keep_ape,
+        dynamic_img_size=vision_cfg.timm_dynamic_img_size,
+        norm_pre=vision_cfg.timm_norm_pre,
+        embed_dim=embed_dim,
+        image_size=vision_cfg.image_size,
+        output_tokens=vision_cfg.output_tokens,
+    )
+
+    return visual
+
+
+def _build_text_tower(
+    embed_dim: int,
+    text_cfg: CLIPTextCfg,
+    quick_gelu: bool = False,
+    cast_dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(text_cfg, dict):
+        text_cfg = CLIPTextCfg(**text_cfg)
+
+    act_layer = QuickGELU if quick_gelu else nn.GELU
+    norm_layer = (
+        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+    )
+    if text_cfg.norm_kwargs:
+        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
+    if text_cfg.act_kwargs is not None:
+        act_layer = partial(act_layer, **text_cfg.act_kwargs)
+
+    text = TextTransformer(
+        context_length=text_cfg.context_length,
+        vocab_size=text_cfg.vocab_size,
+        width=text_cfg.width,
+        heads=text_cfg.heads,
+        layers=text_cfg.layers,
+        mlp_ratio=text_cfg.mlp_ratio,
+        ls_init_value=text_cfg.ls_init_value,
+        output_dim=embed_dim,
+        embed_cls=text_cfg.embed_cls,
+        no_causal_mask=text_cfg.no_causal_mask,
+        pad_id=text_cfg.pad_id,
+        eos_id=text_cfg.eos_id,
+        pool_type=text_cfg.pool_type,
+        proj_type=text_cfg.proj_type,
+        proj_bias=text_cfg.proj_bias,
+        output_tokens=text_cfg.output_tokens,
+        act_layer=act_layer,
+        norm_layer=norm_layer,
+        block_type=text_cfg.block_type,
+        qk_norm=text_cfg.qk_norm,
+        scaled_cosine_attn=text_cfg.scaled_cosine_attn,
+        scale_heads=text_cfg.scale_heads,
+        scale_attn_inner=text_cfg.scale_attn_inner,
+        scale_attn=text_cfg.scale_attn,
+        scale_fc=text_cfg.scale_fc,
+    )
+    return text
+
+
+class CustomTextCLIP(nn.Module):
+    output_dict: torch.jit.Final[bool]
+
+    def __init__(
+        self,
+        embed_dim: int,
+        vision_cfg: CLIPVisionCfg,
+        text_cfg: CLIPTextCfg,
+        quick_gelu: bool = False,
+        init_logit_scale: float = np.log(1 / 0.07),
+        init_logit_bias: Optional[float] = None,
+        nonscalar_logit_scale: bool = False,
+        cast_dtype: Optional[torch.dtype] = None,
+        output_dict: bool = False,
+    ):
+        super().__init__()
+        self.output_dict = output_dict
+        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+        self.context_length = self.text.context_length
+        self.vocab_size = self.text.vocab_size
+
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
+        if init_logit_bias is not None:
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
+        else:
+            self.logit_bias = None
+
+    def encode_image(
+        self, pixel_values, normalize: bool = False, pixel_attention_mask=None, spatial_shapes=None
+    ):
+        kwargs = {}
+        if pixel_attention_mask is not None:
+            kwargs["patch_valid_mask"] = pixel_attention_mask
+        if spatial_shapes is not None:
+            kwargs["spatial_shapes"] = spatial_shapes
+        features = self.visual(pixel_values, **kwargs) if kwargs else self.visual(pixel_values)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def encode_text(self, input_ids, normalize: bool = False):
+        features = self.text(input_ids)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def get_logits(self, image, text):
+        image_features = self.encode_image(pixel_values=image, normalize=True)
+        text_features = self.encode_text(input_ids=text, normalize=True)
+        image_logits = self.logit_scale.exp() * image_features @ text_features.T
+        if self.logit_bias is not None:
+            image_logits += self.logit_bias
+        text_logits = image_logits.T
+        return image_logits, text_logits
+
+    def forward(
+        self, image=None, text=None, patch_valid_mask=None, spatial_shapes=None
+    ):
+        image_features = (
+            self.encode_image(
+                pixel_values=image,
+                normalize=True,
+                pixel_attention_mask=patch_valid_mask,
+                spatial_shapes=spatial_shapes,
+            )
+            if image is not None
+            else None
+        )
+        text_features = (
+            self.encode_text(input_ids=text, normalize=True) if text is not None else None
+        )
+
+        if self.output_dict:
+            out_dict = {
+                "image_features": image_features,
+                "text_features": text_features,
+                "logit_scale": self.logit_scale.exp(),
+            }
+            if self.logit_bias is not None:
+                out_dict["logit_bias"] = self.logit_bias
+            return out_dict
+
+        if self.logit_bias is not None:
+            return (
+                image_features,
+                text_features,
+                self.logit_scale.exp(),
+                self.logit_bias,
+            )
+        return image_features, text_features, self.logit_scale.exp()
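`get_logits` computes `logit_scale.exp() * img @ txt.T` and adds `logit_bias` when present; for SigLIP-style training the per-pair match probability is then the sigmoid of that logit. A minimal sketch with synthetic tensors (the scale and bias values here are illustrative assumptions, not the checkpoint's actual parameters):

```python
import torch
import torch.nn.functional as F

# Hypothetical normalized embeddings: 2 images, 3 texts, dim 4.
img = F.normalize(torch.randn(2, 4), dim=-1)
txt = F.normalize(torch.randn(3, 4), dim=-1)

logit_scale = torch.tensor(4.6052)  # stored as a log (cf. init np.log(1/0.07)), exp'd at use
logit_bias = torch.tensor(-10.0)    # SigLIP-style bias; illustrative value

# Mirrors CustomTextCLIP.get_logits: scaled cosine similarity plus bias.
image_logits = logit_scale.exp() * img @ txt.T + logit_bias  # [2, 3]
text_logits = image_logits.T                                  # [3, 2]
probs = torch.sigmoid(image_logits)  # per-pair match probability
print(probs.shape)  # torch.Size([2, 3])
```

Because each image-text pair gets an independent sigmoid, rows do not sum to 1; use a softmax over `image_logits` instead if mutually exclusive labels are wanted.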
raon_vision_encoder/constants.py ADDED
@@ -0,0 +1,5 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+INCEPTION_MEAN = (0.5, 0.5, 0.5)
+INCEPTION_STD = (0.5, 0.5, 0.5)
raon_vision_encoder/timm_model.py ADDED
@@ -0,0 +1,397 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+
+import logging
+import types
+from collections import OrderedDict
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+try:
+    import timm
+    from timm.layers import RotAttentionPool2d
+    from timm.layers import AttentionPool2d as AbsAttentionPool2d
+    from timm.layers import Mlp, to_2tuple
+    from timm.layers import AttentionRope, RotaryEmbeddingCat
+except ImportError:
+    timm = None
+
+
+class TimmModel(nn.Module):
+    """timm model adapter"""
+
+    def __init__(
+        self,
+        model_name: str,
+        embed_dim: int,
+        image_size: Union[int, Tuple[int, int]] = 224,
+        pool: str = "avg",
+        proj: str = "linear",
+        proj_bias: bool = False,
+        drop: float = 0.0,
+        drop_path: Optional[float] = None,
+        patch_drop: Optional[float] = None,
+        init_values: Optional[float] = None,
+        qk_norm: bool = False,
+        use_rope: bool = False,
+        rope_keep_ape: bool = False,
+        dynamic_img_size: bool = False,
+        norm_pre: bool = False,
+        pretrained: bool = False,
+        output_tokens: bool = False,
+    ):
+        super().__init__()
+        if timm is None:
+            raise RuntimeError(
+                "Please install the latest timm (`pip install timm`) to use timm based models."
+            )
+        self.image_size = to_2tuple(image_size)
+        self.output_tokens = output_tokens
+
+        timm_kwargs = {}
+        if drop_path is not None:
+            timm_kwargs["drop_path_rate"] = drop_path
+        if patch_drop is not None:
+            timm_kwargs["patch_drop_rate"] = patch_drop
+        if init_values is not None:
+            timm_kwargs["init_values"] = init_values
+        if qk_norm:
+            timm_kwargs["qk_norm"] = True
+        if dynamic_img_size:
+            timm_kwargs["dynamic_img_size"] = True
+        if use_rope:
+
+            class _AttentionRopeNoPrefix(AttentionRope):
+                """AttentionRope with num_prefix_tokens=0 for models without cls token."""
+
+                def __init__(self, *args, **kwargs):
+                    kwargs["num_prefix_tokens"] = 0
+                    super().__init__(*args, **kwargs)
+
+            timm_kwargs["attn_layer"] = _AttentionRopeNoPrefix
+            if not rope_keep_ape:
+                timm_kwargs["pos_embed"] = "none"
+
+        custom_pool = pool in ("abs_attn", "rot_attn")
+        if proj:
+            assert proj in ("linear", "mlp", "none")
+        extra_proj = proj in ("linear", "mlp")
+        if not extra_proj and not custom_pool:
+            proj_dim = 0 if proj == "none" else embed_dim
+            self.trunk = timm.create_model(
+                model_name,
+                num_classes=proj_dim,
+                global_pool=pool,
+                pretrained=pretrained,
+                **timm_kwargs,
+            )
+            prev_chs = embed_dim
+        else:
+            self.trunk = timm.create_model(
+                model_name,
+                pretrained=pretrained,
+                **timm_kwargs,
+            )
+            feat_size = self.trunk.default_cfg.get("pool_size", None)
+            feature_ndim = 1 if not feat_size else 2
+            if custom_pool:
+                assert feature_ndim == 2
+                self.trunk.reset_classifier(0, global_pool="")
+            else:
+                reset_kwargs = dict(global_pool=pool) if pool else {}
+                self.trunk.reset_classifier(0, **reset_kwargs)
+            prev_chs = self.trunk.num_features
+
+        head_layers = OrderedDict()
+
+        if pool == "abs_attn":
+            head_layers["pool"] = AbsAttentionPool2d(
+                prev_chs, feat_size=feat_size, out_features=embed_dim
+            )
+            prev_chs = embed_dim
+        elif pool == "rot_attn":
+            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+            prev_chs = embed_dim
+
+        if proj == "linear":
+            head_layers["drop"] = nn.Dropout(drop)
+            head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
+        elif proj == "mlp":
+            head_layers["mlp"] = Mlp(
+                prev_chs,
+                2 * embed_dim,
+                embed_dim,
+                drop=(drop, 0),
+                bias=(True, proj_bias),
+            )
+
+        self.head = nn.Sequential(head_layers)
+
+        if (
+            norm_pre
+            and hasattr(self.trunk, "norm_pre")
+            and isinstance(self.trunk.norm_pre, nn.Identity)
+        ):
+            self.trunk.norm_pre = nn.LayerNorm(self.trunk.embed_dim)
+            logging.info(
+                f"Replaced norm_pre Identity with LayerNorm({self.trunk.embed_dim})"
+            )
+
+        self._has_rope = use_rope
+        if use_rope:
+            self._setup_rope()
+
+    def _setup_rope(self):
+        """Inject 2D Rotary Position Embedding into the timm trunk."""
+        num_heads = self.trunk.blocks[0].attn.num_heads
+        head_dim = self.trunk.embed_dim // num_heads
+
+        self.trunk.patch_embed.strict_img_size = False
+
+        self.rope = RotaryEmbeddingCat(
+            dim=head_dim,
+            max_res=max(self.image_size),
+            in_pixels=True,
+        )
+
+        def _block_forward_rope(block_self, x, rope=None, attn_mask=None):
+            x = x + block_self.drop_path1(
+                block_self.ls1(
+                    block_self.attn(block_self.norm1(x), rope=rope, attn_mask=attn_mask)
+                )
+            )
+            x = x + block_self.drop_path2(
+                block_self.ls2(block_self.mlp(block_self.norm2(x)))
+            )
+            return x
+
+        for blk in self.trunk.blocks:
+            blk.forward = types.MethodType(_block_forward_rope, blk)
+
+        timm_model_ref = self
+        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)
+
+        def _forward_features_rope(trunk_self, x, attn_mask=None):
+            from torch.utils.checkpoint import checkpoint
+            from timm.layers import resample_abs_pos_embed
+
+            ps = trunk_self.patch_embed.patch_size
+            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]
+
+            x = trunk_self.patch_embed(x)
+            if x.ndim == 4:
+                x = x.reshape(x.shape[0], -1, x.shape[-1])
+            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
+                if x.shape[1] != trunk_self.pos_embed.shape[1]:
+                    x = x + resample_abs_pos_embed(
+                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
+                    )
+                else:
+                    x = x + trunk_self.pos_embed
+            x = trunk_self.pos_drop(x)
+            x = trunk_self.norm_pre(x)
+
+            rot_pos_embed = timm_model_ref.rope.get_embed(shape=grid_shape)
+
+            _sdpa_mask = None
+            if attn_mask is not None:
+                _sdpa_mask = torch.zeros_like(attn_mask, dtype=x.dtype)
+                _sdpa_mask.masked_fill_(~attn_mask, float("-inf"))
+                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
+
+            for blk in trunk_self.blocks:
+                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
+                    x = checkpoint(
+                        blk,
+                        x,
+                        rope=rot_pos_embed,
+                        attn_mask=_sdpa_mask,
+                        use_reentrant=False,
+                    )
+                else:
+                    x = blk(x, rope=rot_pos_embed, attn_mask=_sdpa_mask)
+
+            x = trunk_self.norm(x)
+            return x
+
+        self.trunk.forward_features = types.MethodType(
+            _forward_features_rope, self.trunk
+        )
+
+    def _setup_dynamic_pos_embed(self):
+        """Patch forward_features for variable-resolution pos_embed interpolation (non-RoPE)."""
+        self.trunk.patch_embed.strict_img_size = False
+        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)
+
+        def _forward_features_dynamic(trunk_self, x, patch_valid_mask=None):
+            from torch.utils.checkpoint import checkpoint
+            from timm.layers import resample_abs_pos_embed
+
+            ps = trunk_self.patch_embed.patch_size
+            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]
+
+            x = trunk_self.patch_embed(x)
+            if x.ndim == 4:
+                x = x.reshape(x.shape[0], -1, x.shape[-1])
+            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
+                if x.shape[1] != trunk_self.pos_embed.shape[1]:
+                    x = x + resample_abs_pos_embed(
+                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
+                    )
+                else:
+                    x = x + trunk_self.pos_embed
+            x = trunk_self.pos_drop(x)
+            x = trunk_self.norm_pre(x)
+
+            _sdpa_mask = None
+            if patch_valid_mask is not None:
+                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
+                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
+                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
+
+            for blk in trunk_self.blocks:
+                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
+                    if _sdpa_mask is not None:
+                        x = checkpoint(
+                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
+                        )
+                    else:
+                        x = checkpoint(blk, x, use_reentrant=False)
+                else:
+                    x = blk(x, attn_mask=_sdpa_mask)
+
+            x = trunk_self.norm(x)
+            return x
+
+        self.trunk.forward_features = types.MethodType(
+            _forward_features_dynamic, self.trunk
+        )
+
+    def _setup_1d_forward(self):
+        """Patch forward_features for NaFlex 1D mode (SigLIP2 style)."""
+        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)
+
+        def _forward_features_1d(
+            trunk_self, x, patch_valid_mask=None, spatial_shapes=None
+        ):
+            from torch.utils.checkpoint import checkpoint
+
+            conv = trunk_self.patch_embed.proj
+            D = conv.weight.shape[0]
+            x = torch.nn.functional.linear(
+                x.to(conv.weight.dtype), conv.weight.reshape(D, -1), conv.bias
+            )
+
+            if (
+                hasattr(trunk_self, "pos_embed")
+                and trunk_self.pos_embed is not None
+                and spatial_shapes is not None
+            ):
+                pos_embed = trunk_self.pos_embed
+                base_n = pos_embed.shape[1]
+                base_grid = int(base_n**0.5)
+                pos_2d = (
+                    pos_embed.reshape(1, base_grid, base_grid, -1)
+                    .permute(0, 3, 1, 2)
+                    .float()
+                )
+
+                B, sl, D_emb = x.shape
+                pos_resized = torch.zeros(B, sl, D_emb, device=x.device, dtype=x.dtype)
+
+                for i in range(B):
+                    gh, gw = spatial_shapes[i].tolist()
+                    pe = torch.nn.functional.interpolate(
+                        pos_2d, size=(gh, gw), mode="bilinear", align_corners=False
+                    )
+                    pe = pe.squeeze(0).permute(1, 2, 0).reshape(gh * gw, -1).to(x.dtype)
+                    n_patches = gh * gw
+                    pos_resized[i, :n_patches] = pe
+                    if n_patches < sl:
+                        pos_resized[i, n_patches:] = pe[0]
+
+                x = x + pos_resized
+            elif hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
+                x = x + trunk_self.pos_embed
+
+            x = trunk_self.pos_drop(x)
+            x = trunk_self.norm_pre(x)
+
+            _sdpa_mask = None
+            if patch_valid_mask is not None:
+                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
+                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
+                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
+
+            for blk in trunk_self.blocks:
+                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
+                    if _sdpa_mask is not None:
+                        x = checkpoint(
+                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
+                        )
+                    else:
+                        x = checkpoint(blk, x, use_reentrant=False)
+                else:
+                    x = blk(x, attn_mask=_sdpa_mask)
+
+            x = trunk_self.norm(x)
+            return x
+
+        self.trunk._forward_features_1d = types.MethodType(
+            _forward_features_1d, self.trunk
+        )
+        self._has_1d_forward = True
+
+    def forward_patch_features(self, x):
+        """Forward pass returning per-patch features (before pooling/projection)."""
+        return self.trunk.forward_features(x)
+
+    def forward(self, x, patch_valid_mask=None, spatial_shapes=None):
+        if spatial_shapes is not None and getattr(self, "_has_1d_forward", False):
+            patch_features = self.trunk._forward_features_1d(
+                x, patch_valid_mask=patch_valid_mask, spatial_shapes=spatial_shapes
+            )
+        elif patch_valid_mask is not None and self._has_rope:
+            patch_features = self.trunk.forward_features(x, attn_mask=patch_valid_mask)
+        elif patch_valid_mask is not None:
+            patch_features = self.trunk.forward_features(
+                x, patch_valid_mask=patch_valid_mask
+            )
+        else:
+            patch_features = self.trunk.forward_features(x)
+        if patch_valid_mask is not None:
+            mask_f = patch_valid_mask.unsqueeze(-1).to(patch_features.dtype)
+            patch_features = patch_features * mask_f
+        self._cached_patch_features = patch_features
+        if (
+            patch_valid_mask is not None
+            and getattr(self.trunk, "global_pool", "") == "avg"
+        ):
+            pooled = patch_features.sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
+            pooled = (
+                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
+            )
+        elif (
+            patch_valid_mask is not None
+            and getattr(self.trunk, "attn_pool", None) is not None
+        ):
+            attn_mask = torch.zeros(
+                patch_valid_mask.shape,
+                dtype=patch_features.dtype,
+                device=patch_features.device,
+            )
+            attn_mask.masked_fill_(~patch_valid_mask.bool(), float("-inf"))
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+            pooled = self.trunk.attn_pool(patch_features, attn_mask=attn_mask)
+            pooled = (
+                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
+            )
+        else:
+            pooled = self.trunk.forward_head(patch_features)
+        pooled = self.head(pooled)
+        if self.output_tokens:
+            return pooled, patch_features
+        return pooled
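When `global_pool == "avg"` and a validity mask is supplied, `TimmModel.forward` pools by zeroing padded patches and dividing the sum by the count of valid patches, so padding never dilutes the mean. A minimal sketch of that masked average pooling with hand-picked values:

```python
import torch

# [1, 3, D]: two valid patches plus one padded row that must be ignored.
patch_features = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
patch_valid_mask = torch.tensor([[True, True, False]])

mask_f = patch_valid_mask.unsqueeze(-1).to(patch_features.dtype)
patch_features = patch_features * mask_f                 # zero out padded patches
pooled = patch_features.sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
print(pooled)  # tensor([[2., 3.]])
```

The `clamp(min=1)` guards against division by zero for an all-padding sequence; a naive `patch_features.mean(dim=1)` would instead give [[4.33, 5.0]] here, contaminated by the pad row.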
raon_vision_encoder/tokenizer.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
2
+
3
+ import html
4
+ import os
5
+ import string
6
+ from typing import List, Optional, Union
7
+ import warnings
8
+
9
+ try:
10
+ import ftfy
11
+ except ImportError:
12
+ ftfy = None
13
+ import torch
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ DEFAULT_CONTEXT_LENGTH = 77
18
+
19
+
20
+ def basic_clean(text):
21
+ if ftfy is not None:
22
+ text = ftfy.fix_text(text)
23
+ else:
24
+ text
25
+ text = html.unescape(html.unescape(text))
26
+ return text.strip()
27
+
28
+
29
+ def whitespace_clean(text):
30
+ text = " ".join(text.split())
31
+ text = text.strip()
32
+ return text
33
+
34
+
35
+ def _clean_canonicalize(x):
36
+ return canonicalize_text(basic_clean(x))
37
+
38
+
39
+ def _clean_lower(x):
40
+ return whitespace_clean(basic_clean(x)).lower()
41
+
42
+
43
+ def _clean_whitespace(x):
44
+ return whitespace_clean(basic_clean(x))
45
+
46
+
47
+ def get_clean_fn(type: str):
48
+ if type == "canonicalize":
49
+ return _clean_canonicalize
50
+ elif type == "lower":
51
+ return _clean_lower
52
+ elif type == "whitespace":
53
+ return _clean_whitespace
54
+ else:
55
+ assert False, f"Invalid clean function ({type})."
56
+
57
+
58
+ def canonicalize_text(
59
+ text,
60
+ *,
61
+ keep_punctuation_exact_string=None,
62
+ trans_punctuation: dict = str.maketrans("", "", string.punctuation),
63
+ ):
64
+ """Returns canonicalized `text` (lowercase and punctuation removed)."""
65
+ text = text.replace("_", " ")
66
+ if keep_punctuation_exact_string:
67
+ text = keep_punctuation_exact_string.join(
68
+ part.translate(trans_punctuation)
69
+ for part in text.split(keep_punctuation_exact_string)
70
+ )
71
+ else:
72
+ text = text.translate(trans_punctuation)
73
+ text = text.lower()
74
+ text = " ".join(text.split())
75
+ return text.strip()
76
+
77
+
78
+ class HFTokenizer:
79
+ """HuggingFace tokenizer wrapper with support for custom tokenization modes"""
80
+
81
+ def __init__(
82
+ self,
83
+ tokenizer_name: str,
84
+ context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
85
+ clean: str = "whitespace",
86
+ strip_sep_token: bool = False,
87
+ language: Optional[str] = None,
88
+ cache_dir: Optional[str] = None,
89
+ tokenizer_mode: Optional[str] = None,
90
+ **kwargs,
91
+ ):
92
+ self.tokenizer_mode = tokenizer_mode or ""
93
+ self.context_length = context_length
94
+ self.clean_fn = get_clean_fn(clean)
95
+        self.strip_sep_token = strip_sep_token
+
+        from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name, cache_dir=cache_dir, **kwargs
+        )
+
+        set_lang_fn = getattr(self.tokenizer, "set_src_lang_special_tokens", None)
+        if callable(set_lang_fn):
+            self.set_lang_fn = set_lang_fn
+        if language is not None:
+            self.set_language(language)
+
+    def save_pretrained(self, dest):
+        self.tokenizer.save_pretrained(dest)
+
+    def __call__(
+        self, texts: Union[str, List[str]], context_length: Optional[int] = None
+    ) -> torch.Tensor:
+        if isinstance(texts, str):
+            texts = [texts]
+
+        context_length = context_length or self.context_length
+        assert context_length, (
+            "Please set a valid context length in class init or call."
+        )
+
+        texts = [self.clean_fn(text) for text in texts]
+
+        if self.tokenizer_mode == "clips":
+            return self._clips_tokenize(texts, context_length)
+        else:
+            output = self.tokenizer(
+                texts,
+                return_tensors="pt",
+                max_length=context_length,
+                padding="max_length",
+                truncation=True,
+            )
+            input_ids = output.input_ids
+
+            if self.strip_sep_token:
+                input_ids = torch.where(
+                    input_ids == self.tokenizer.sep_token_id,
+                    torch.zeros_like(input_ids),
+                    input_ids,
+                )
+
+            return input_ids
+
+    def set_language(self, src_lang):
+        if hasattr(self, "set_lang_fn"):
+            self.set_lang_fn(src_lang)
+        else:
+            warnings.warn("Cannot set language for the tokenizer.")
+
+    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
+        encoded_outputs = self.tokenizer(
+            texts,
+            add_special_tokens=False,
+            padding=False,
+            truncation=False,
+            return_tensors=None,
+        )
+
+        encoded = []
+        for tokens in encoded_outputs["input_ids"]:
+            tokens = tokens[: context_length - 3]
+            tokens = (
+                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
+            )
+            encoded.append(tokens)
+
+        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
+        for i, tokens in enumerate(encoded):
+            padded_tokens = self._pad_and_add_class_token(
+                tokens,
+                max_length=context_length,
+                pad_token_id=self.tokenizer.pad_token_id,
+                cls_token_id=self.tokenizer.cls_token_id,
+            )
+            result[i, : len(padded_tokens)] = torch.tensor(padded_tokens)
+
+        return result
+
+    def _pad_and_add_class_token(
+        self,
+        tokens: List[int],
+        max_length: int,
+        pad_token_id: int = 0,
+        cls_token_id: int = 101,
+    ) -> List[int]:
+        if len(tokens) > max_length - 1:
+            tokens = tokens[: max_length - 1]
+        if len(tokens) < max_length - 1:
+            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
+        tokens = tokens + [cls_token_id]
+        return tokens
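The CLIPS-style padding always reserves the final slot for the class token: sequences are truncated or right-padded to `max_length - 1` and the `[CLS]` id is appended last. A minimal pure-Python sketch of that helper (standalone function name is ours; no tokenizer required):

```python
from typing import List


def pad_and_add_class_token(
    tokens: List[int],
    max_length: int,
    pad_token_id: int = 0,
    cls_token_id: int = 101,
) -> List[int]:
    # Truncate or right-pad to max_length - 1, then append the CLS id last.
    if len(tokens) > max_length - 1:
        tokens = tokens[: max_length - 1]
    if len(tokens) < max_length - 1:
        tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
    return tokens + [cls_token_id]


# Short sequences are right-padded before the trailing CLS id.
print(pad_and_add_class_token([1, 2, 3], max_length=8))
# Long sequences are truncated so the CLS id still fits in the last slot.
print(pad_and_add_class_token(list(range(20)), max_length=8))
```

Note the output length is always exactly `max_length`, so the tensors built in `_clips_tokenize` never need further padding.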
raon_vision_encoder/transform.py ADDED
@@ -0,0 +1,44 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+
+import math
+
+
+def get_image_size_for_max_num_patches(
+    image_height, image_width, patch_size, max_num_patches
+):
+    """Find a target image size that preserves aspect ratio within a patch budget.
+
+    Uses binary search to find the largest scale such that
+    ceil(h * scale / ps) * ceil(w * scale / ps) <= max_num_patches.
+
+    Args:
+        image_height: Original image height.
+        image_width: Original image width.
+        patch_size: Patch size (int).
+        max_num_patches: Maximum number of patches allowed.
+
+    Returns:
+        (target_h, target_w), both multiples of patch_size.
+    """
+    scale_min, scale_max = 1e-6, 100.0
+    eps = 1e-5
+    while (scale_max - scale_min) >= eps:
+        scale = (scale_min + scale_max) / 2
+        target_h = max(
+            patch_size, int(math.ceil(image_height * scale / patch_size) * patch_size)
+        )
+        target_w = max(
+            patch_size, int(math.ceil(image_width * scale / patch_size) * patch_size)
+        )
+        num_patches = (target_h // patch_size) * (target_w // patch_size)
+        if num_patches <= max_num_patches:
+            scale_min = scale
+        else:
+            scale_max = scale
+    target_h = max(
+        patch_size, int(math.ceil(image_height * scale_min / patch_size) * patch_size)
+    )
+    target_w = max(
+        patch_size, int(math.ceil(image_width * scale_min / patch_size) * patch_size)
+    )
+    return target_h, target_w
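The binary search converges on the largest feasible scale, so the returned size satisfies two invariants: both sides are multiples of `patch_size`, and the patch grid fits the budget. Since the function only needs `math`, a self-contained copy makes those invariants easy to check:

```python
import math


def get_image_size_for_max_num_patches(image_height, image_width, patch_size, max_num_patches):
    # Binary-search the resize scale so the rounded-up patch grid fits the budget.
    scale_min, scale_max = 1e-6, 100.0
    eps = 1e-5
    while (scale_max - scale_min) >= eps:
        scale = (scale_min + scale_max) / 2
        target_h = max(patch_size, int(math.ceil(image_height * scale / patch_size) * patch_size))
        target_w = max(patch_size, int(math.ceil(image_width * scale / patch_size) * patch_size))
        if (target_h // patch_size) * (target_w // patch_size) <= max_num_patches:
            scale_min = scale  # feasible: try a larger scale
        else:
            scale_max = scale  # over budget: shrink
    target_h = max(patch_size, int(math.ceil(image_height * scale_min / patch_size) * patch_size))
    target_w = max(patch_size, int(math.ceil(image_width * scale_min / patch_size) * patch_size))
    return target_h, target_w


h, w = get_image_size_for_max_num_patches(1024, 768, patch_size=16, max_num_patches=1024)
# Both sides are multiples of patch_size and the grid stays within budget.
assert h % 16 == 0 and w % 16 == 0
assert (h // 16) * (w // 16) <= 1024
print(h, w)
```

The `max(patch_size, ...)` clamp guarantees at least one patch per axis, so the search is always feasible for any `max_num_patches >= 1`.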
raon_vision_encoder/transformer.py ADDED
@@ -0,0 +1,627 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Type, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+
+class LayerNormFp32(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        x = F.layer_norm(
+            x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps
+        )
+        return x.to(orig_type)
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+    # NOTE: this is slower than nn.GELU or nn.SiLU and uses more GPU memory
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        scaled_cosine: bool = False,
+        scale_heads: bool = False,
+        inner_norm: bool = False,
+        logit_scale_max: float = math.log(1.0 / 0.01),
+        norm_layer: Type[nn.Module] = LayerNormFp32,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ):
+        super().__init__()
+        assert not (scaled_cosine and qk_norm), (
+            "Cannot activate both scaled cosine and QK normalization"
+        )
+        self.scaled_cosine = scaled_cosine
+        self.scale_heads = scale_heads
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.logit_scale_max = logit_scale_max
+        self.use_fsdpa = hasattr(nn.functional, "scaled_dot_product_attention")
+
+        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+        if qkv_bias:
+            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+        else:
+            self.in_proj_bias = None
+
+        if qk_norm:
+            self.ln_q = norm_layer(self.head_dim)
+            self.ln_k = norm_layer(self.head_dim)
+        else:
+            self.ln_q = nn.Identity()
+            self.ln_k = nn.Identity()
+
+        if self.scaled_cosine:
+            self.logit_scale = nn.Parameter(
+                torch.log(10 * torch.ones((num_heads, 1, 1)))
+            )
+        else:
+            self.logit_scale = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+
+        if self.scale_heads:
+            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+        else:
+            self.head_scale = None
+
+        if inner_norm:
+            self.ln_inner = norm_layer(dim)
+        else:
+            self.ln_inner = nn.Identity()
+
+        self.out_proj = nn.Linear(dim, dim)
+        self.out_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+        N, L, C = x.shape
+        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+        q = q.reshape(N, L, self.num_heads, -1).transpose(1, 2)
+        k = k.reshape(N, L, self.num_heads, -1).transpose(1, 2)
+        v = v.reshape(N, L, self.num_heads, -1).transpose(1, 2)
+
+        if attn_mask is not None:
+            if attn_mask.ndim == 3:
+                attn_mask = attn_mask.reshape(N, self.num_heads, L, L)
+            if attn_mask.dtype == torch.bool:
+                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+                attn_mask = new_attn_mask
+            else:
+                attn_mask = attn_mask.to(dtype=q.dtype)
+
+        if self.logit_scale is not None:
+            # q, k, v are (N, num_heads, L, head_dim); use matmul since
+            # torch.bmm only accepts 3-D tensors.
+            attn = torch.matmul(
+                F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
+            )
+            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+            attn = attn * logit_scale
+            if attn_mask is not None:
+                attn = attn + attn_mask
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = torch.matmul(attn, v)
+        else:
+            q = self.ln_q(q)
+            k = self.ln_k(k)
+            if self.use_fsdpa:
+                x = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=attn_mask,
+                    dropout_p=self.attn_drop.p if self.training else 0.0,
+                )
+            else:
+                q = q * self.scale
+                attn = torch.matmul(q, k.transpose(-1, -2))
+                if attn_mask is not None:
+                    attn += attn_mask
+                attn = attn.softmax(dim=-1)
+                attn = self.attn_drop(attn)
+                x = torch.matmul(attn, v)
+
+        if self.head_scale is not None:
+            x = x * self.head_scale
+        x = x.transpose(1, 2).reshape(N, L, C)
+        x = self.ln_inner(x)
+        x = self.out_proj(x)
+        x = self.out_drop(x)
+        return x
+
+
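The non-SDPA fallback in `Attention.forward` computes `softmax(q·kᵀ/√d + mask)·v` explicitly. The same arithmetic for a single head in plain Python (torch-free; the helper name `attention_scores` is ours, for illustration only):

```python
import math


def attention_scores(q, k, v, mask=None):
    # q, k, v: lists of equal-length vectors, one per position.
    # Returns one attended output vector per query position.
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product logits against every key.
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            # Additive mask: -inf entries zero out the corresponding weight.
            logits = [l + m for l, m in zip(logits, mask[i])]
        mx = max(logits)
        exps = [math.exp(l - mx) for l in logits]  # exp(-inf) == 0.0
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append(
            [sum(w * vj[t] for w, vj in zip(weights, v)) for t in range(len(v[0]))]
        )
    return out


q = k = v = [[1.0, 0.0], [0.0, 1.0]]
causal = [[0.0, float("-inf")], [0.0, 0.0]]
out = attention_scores(q, k, v, mask=causal)
assert out[0] == [1.0, 0.0]  # position 0 attends only to itself
```

This mirrors why the module converts boolean masks to additive `-inf` masks before the softmax: a `-inf` logit becomes an exact zero weight.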
+class ResidualAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        mlp_ratio: float = 4.0,
+        ls_init_value: Optional[float] = None,
+        act_layer: Callable = nn.GELU,
+        norm_layer: Callable = LayerNorm,
+        is_cross_attention: bool = False,
+        batch_first: bool = True,
+    ):
+        super().__init__()
+
+        self.ln_1 = norm_layer(d_model)
+        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
+        self.ls_1 = (
+            LayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+        if is_cross_attention:
+            self.ln_1_kv = norm_layer(d_model)
+
+        self.ln_2 = norm_layer(d_model)
+        mlp_width = int(d_model * mlp_ratio)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, mlp_width)),
+                    ("gelu", act_layer()),
+                    ("c_proj", nn.Linear(mlp_width, d_model)),
+                ]
+            )
+        )
+        self.ls_2 = (
+            LayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+
+    def get_weight_dtype(self) -> torch.dtype:
+        if hasattr(self.mlp.c_fc, "int8_original_dtype"):
+            return self.mlp.c_fc.int8_original_dtype
+        return self.mlp.c_fc.weight.dtype
+
+    def attention(
+        self,
+        q_x: torch.Tensor,
+        k_x: Optional[torch.Tensor] = None,
+        v_x: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ):
+        k_x = k_x if k_x is not None else q_x
+        v_x = v_x if v_x is not None else q_x
+
+        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
+        return self.attn(
+            q_x,
+            k_x,
+            v_x,
+            need_weights=False,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+        )[0]
+
+    def forward(
+        self,
+        q_x: torch.Tensor,
+        k_x: Optional[torch.Tensor] = None,
+        v_x: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ):
+        k_x = (
+            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
+        )
+        v_x = (
+            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
+        )
+        x = q_x + self.ls_1(
+            self.attention(
+                q_x=self.ln_1(q_x),
+                k_x=k_x,
+                v_x=v_x,
+                attn_mask=attn_mask,
+                key_padding_mask=key_padding_mask,
+            )
+        )
+        x = x + self.ls_2(self.mlp(self.ln_2(x)))
+        return x
+
+
+class CustomResidualAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        mlp_ratio: float = 4.0,
+        ls_init_value: Optional[float] = None,
+        act_layer: Type[nn.Module] = nn.GELU,
+        norm_layer: Type[nn.Module] = LayerNorm,
+        qk_norm: bool = False,
+        scale_cosine_attn: bool = False,
+        scale_heads: bool = False,
+        scale_attn_inner: bool = False,
+        scale_attn: bool = False,
+        scale_fc: bool = False,
+        batch_first: bool = True,
+    ):
+        super().__init__()
+        assert batch_first, "batch_first must be True for CustomResidualAttentionBlock"
+
+        self.ln_1 = norm_layer(d_model)
+        self.attn = Attention(
+            d_model,
+            n_head,
+            qk_norm=qk_norm,
+            scaled_cosine=scale_cosine_attn,
+            scale_heads=scale_heads,
+            inner_norm=scale_attn_inner,
+            norm_layer=norm_layer,
+        )
+        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
+        self.ls_1 = (
+            LayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+
+        self.ln_2 = norm_layer(d_model)
+        mlp_width = int(d_model * mlp_ratio)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, mlp_width)),
+                    ("gelu", act_layer()),
+                    ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()),
+                    ("c_proj", nn.Linear(mlp_width, d_model)),
+                ]
+            )
+        )
+        self.ls_2 = (
+            LayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+
+    def get_weight_dtype(self) -> torch.dtype:
+        if hasattr(self.mlp.c_fc, "int8_original_dtype"):
+            return self.mlp.c_fc.int8_original_dtype
+        return self.mlp.c_fc.weight.dtype
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
+        x = x + self.ls_2(self.mlp(self.ln_2(x)))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        width: int,
+        layers: int,
+        heads: int,
+        mlp_ratio: float = 4.0,
+        ls_init_value: Optional[float] = None,
+        act_layer: Type[nn.Module] = nn.GELU,
+        norm_layer: Type[nn.Module] = LayerNorm,
+        batch_first: bool = True,
+        block_type: Optional[str] = None,
+        qk_norm: bool = False,
+        scaled_cosine_attn: bool = False,
+        scale_heads: bool = False,
+        scale_attn_inner: bool = False,
+        scale_attn: bool = False,
+        scale_fc: bool = False,
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.batch_first = batch_first
+        self.grad_checkpointing = False
+
+        if block_type is None:
+            if any(
+                [
+                    qk_norm,
+                    scaled_cosine_attn,
+                    scale_heads,
+                    scale_attn_inner,
+                    scale_attn,
+                    scale_fc,
+                ]
+            ):
+                block_type = "custom"
+            else:
+                block_type = "default"
+
+        if block_type == "custom":
+            self.resblocks = nn.ModuleList(
+                [
+                    CustomResidualAttentionBlock(
+                        width,
+                        heads,
+                        mlp_ratio,
+                        ls_init_value=ls_init_value,
+                        act_layer=act_layer,
+                        norm_layer=norm_layer,
+                        qk_norm=qk_norm,
+                        scale_cosine_attn=scaled_cosine_attn,
+                        scale_heads=scale_heads,
+                        scale_attn_inner=scale_attn_inner,
+                        scale_attn=scale_attn,
+                        scale_fc=scale_fc,
+                        batch_first=batch_first,
+                    )
+                    for _ in range(layers)
+                ]
+            )
+        else:
+            self.resblocks = nn.ModuleList(
+                [
+                    ResidualAttentionBlock(
+                        width,
+                        heads,
+                        mlp_ratio,
+                        ls_init_value=ls_init_value,
+                        act_layer=act_layer,
+                        norm_layer=norm_layer,
+                        batch_first=batch_first,
+                    )
+                    for _ in range(layers)
+                ]
+            )
+
+    def get_cast_dtype(self) -> torch.dtype:
+        return self.resblocks[0].get_weight_dtype()
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        if not self.batch_first:
+            x = x.transpose(0, 1).contiguous()
+
+        for r in self.resblocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Pass attn_mask by keyword so both block types' forward
+                # signatures are satisfied (custom blocks take (x, attn_mask)).
+                x = checkpoint(r, x, attn_mask=attn_mask, use_reentrant=False)
+            else:
+                x = r(x, attn_mask=attn_mask)
+
+        if not self.batch_first:
+            x = x.transpose(0, 1)
+        return x
+
+
+def _expand_token(token, batch_size: int):
+    return token.view(1, 1, -1).expand(batch_size, -1, -1)
+
+
+def text_global_pool(
+    x: torch.Tensor,
+    text: Optional[torch.Tensor] = None,
+    pool_type: str = "argmax",
+    eos_token_id: Optional[int] = None,
+) -> torch.Tensor:
+    if pool_type == "first":
+        pooled = x[:, 0]
+    elif pool_type == "last":
+        pooled = x[:, -1]
+    elif pool_type == "argmax":
+        assert text is not None
+        pooled = x[torch.arange(x.shape[0], device=x.device), text.argmax(dim=-1)]
+    elif pool_type == "eos":
+        assert text is not None
+        assert eos_token_id is not None
+        idx = (text == eos_token_id).int().argmax(dim=-1)
+        pooled = x[torch.arange(x.shape[0], device=x.device), idx]
+    else:
+        pooled = x
+
+    return pooled
+
+
+class TextTransformer(nn.Module):
+    output_tokens: torch.jit.Final[bool]
+
+    def __init__(
+        self,
+        context_length: int = 77,
+        vocab_size: int = 49408,
+        width: int = 512,
+        heads: int = 8,
+        layers: int = 12,
+        mlp_ratio: float = 4.0,
+        ls_init_value: Optional[float] = None,
+        output_dim: Optional[int] = 512,
+        embed_cls: bool = False,
+        no_causal_mask: bool = False,
+        use_pad_mask: bool = False,
+        correct_cls_mask: bool = False,
+        pad_id: int = 0,
+        eos_id: int = 2,
+        pool_type: str = "argmax",
+        proj_type: str = "linear",
+        proj_bias: bool = False,
+        act_layer: Type[nn.Module] = nn.GELU,
+        norm_layer: Type[nn.Module] = LayerNorm,
+        output_tokens: bool = False,
+        block_type: Optional[str] = None,
+        qk_norm: bool = False,
+        scaled_cosine_attn: bool = False,
+        scale_heads: bool = False,
+        scale_attn_inner: bool = False,
+        scale_attn: bool = False,
+        scale_fc: bool = False,
+    ):
+        super().__init__()
+        assert pool_type in ("first", "last", "argmax", "eos", "none")
+        self.output_tokens = output_tokens
+        self.num_pos = self.context_length = context_length
+        self.vocab_size = vocab_size
+        self.width = width
+        self.output_dim = output_dim
+        self.heads = heads
+        self.pad_id = pad_id
+        self.eos_id = eos_id
+        self.pool_type = pool_type
+        self.use_pad_mask = use_pad_mask and no_causal_mask
+        self.correct_cls_mask = correct_cls_mask
+
+        self.token_embedding = nn.Embedding(vocab_size, width)
+        if embed_cls:
+            self.cls_emb = nn.Parameter(torch.empty(width))
+            self.num_pos += 1
+        else:
+            self.cls_emb = None
+        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+        self.transformer = Transformer(
+            width=width,
+            layers=layers,
+            heads=heads,
+            mlp_ratio=mlp_ratio,
+            ls_init_value=ls_init_value,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            block_type=block_type,
+            qk_norm=qk_norm,
+            scaled_cosine_attn=scaled_cosine_attn,
+            scale_heads=scale_heads,
+            scale_attn_inner=scale_attn_inner,
+            scale_attn=scale_attn,
+            scale_fc=scale_fc,
+        )
+        self.ln_final = norm_layer(width)
+
+        if no_causal_mask:
+            self.attn_mask = None
+        else:
+            self.register_buffer(
+                "attn_mask", self.build_causal_mask(), persistent=False
+            )
+
+        if proj_type == "none" or not output_dim:
+            self.text_projection = None
+        else:
+            if proj_bias:
+                self.text_projection = nn.Linear(width, output_dim)
+            else:
+                self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+        self.init_parameters()
+
+    def init_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+        if self.cls_emb is not None:
+            nn.init.normal_(self.cls_emb, std=0.01)
+
+        proj_std = (self.transformer.width**-0.5) * (
+            (2 * self.transformer.layers) ** -0.5
+        )
+        attn_std = self.transformer.width**-0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            if isinstance(self.text_projection, nn.Linear):
+                nn.init.normal_(
+                    self.text_projection.weight, std=self.transformer.width**-0.5
+                )
+                if self.text_projection.bias is not None:
+                    nn.init.zeros_(self.text_projection.bias)
+            else:
+                nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
+
+    def build_causal_mask(self):
+        mask = torch.empty(self.num_pos, self.num_pos)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)
+        return mask
+
+    def _build_additive_mask(self, text, seq_len, dtype):
+        valid = text != self.pad_id
+        if self.cls_emb is not None:
+            cls_valid = valid.new_ones(valid.size(0), 1)
+            valid = torch.cat(
+                [valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1
+            )
+        key_mask = valid.unsqueeze(1).expand(-1, seq_len, -1)
+        additive = torch.zeros_like(key_mask, dtype=dtype)
+        additive.masked_fill_(~key_mask, float("-inf"))
+        additive = additive.repeat_interleave(self.heads, 0)
+        return additive
+
+    def _embeds(self, text):
+        cast_dtype = self.transformer.get_cast_dtype()
+        B, seq_len = text.shape
+        x = self.token_embedding(text).to(cast_dtype)
+        if self.cls_emb is not None:
+            x = torch.cat([x, _expand_token(self.cls_emb, x.size(0))], 1)
+            seq_len += 1
+        attn_mask = self.attn_mask
+        if self.use_pad_mask or self.cls_emb is not None:
+            add_mask = self._build_additive_mask(text, seq_len, x.dtype)
+            if attn_mask is not None:
+                attn_mask = attn_mask[:seq_len, :seq_len].unsqueeze(0) + add_mask
+            else:
+                attn_mask = add_mask
+        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
+        return x, attn_mask
+
+    def forward(self, text):
+        x, attn_mask = self._embeds(text)
+        x = self.transformer(x, attn_mask=attn_mask)
+        if self.cls_emb is not None:
+            pooled = text_global_pool(x, pool_type="last")
+            pooled = self.ln_final(pooled)
+            tokens = x[:, :-1]
+        else:
+            x = self.ln_final(x)
+            pooled = text_global_pool(
+                x,
+                text,
+                pool_type=self.pool_type,
+                eos_token_id=getattr(self, "eos_id", None),
+            )
+            tokens = x
+        if self.text_projection is not None:
+            if isinstance(self.text_projection, nn.Linear):
+                pooled = self.text_projection(pooled)
+            else:
+                pooled = pooled @ self.text_projection
+        if self.output_tokens:
+            return pooled, tokens
+        return pooled
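`build_causal_mask` fills the strict upper triangle (column index greater than row index) with `-inf`, so position i can only attend to positions up to and including i. The same additive mask in plain Python (torch-free sketch; our own helper name):

```python
def causal_mask(num_pos: int):
    # mask[i][j] = -inf when j > i (future positions), else 0.0
    return [
        [float("-inf") if j > i else 0.0 for j in range(num_pos)]
        for i in range(num_pos)
    ]


mask = causal_mask(4)
assert mask[0][1] == float("-inf")               # cannot attend to the future
assert mask[2][0] == 0.0                         # past positions stay visible
assert all(mask[i][i] == 0.0 for i in range(4))  # self-attention is allowed
```

Because the mask is additive, it is simply summed with the attention logits before the softmax, which is also why `_embeds` can combine it with the padding mask by addition.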
raon_vision_encoder/utils.py ADDED
@@ -0,0 +1,16 @@
+# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
+
+import collections.abc
+from itertools import repeat
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_2tuple = _ntuple(2)
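`to_2tuple` broadcasts a scalar into a pair while passing iterables through unchanged, which is how size-like arguments can accept either `16` or `(16, 16)`. A standalone copy showing the behavior:

```python
import collections.abc
from itertools import repeat


def _ntuple(n):
    def parse(x):
        # Iterables (tuples, lists) pass through untouched;
        # scalars are repeated n times into a tuple.
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

print(to_2tuple(16))       # scalar -> (16, 16)
print(to_2tuple((8, 32)))  # already a pair -> returned unchanged
```

One caveat: strings are iterable, so `to_2tuple("ab")` returns the string as-is rather than a pair; callers are expected to pass numbers or number sequences.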