vpraveen-nv committed on
Commit
6cf9f8d
·
verified ·
1 Parent(s): b9a6f88

Upload Cosmos-Embed1-448p anomaly-detection fine-tune (LoRA, Vad-Reasoning)

Browse files
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CosmosEmbed1"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_embed1.CosmosEmbed1Config",
7
+ "AutoProcessor": "preprocessing_embed1.CosmosEmbed1Processor",
8
+ "AutoModel": "modeling_embed1.CosmosEmbed1"
9
+ },
10
+ "model_type": "cosmos-embed1",
11
+ "embed_dim": 768,
12
+ "num_query_tokens": 32,
13
+ "max_txt_len": 128,
14
+ "num_video_frames": 8,
15
+ "resolution": 448,
16
+ "temporal_encoding_type": "neighboring_token_propagation",
17
+ "vocab_size": 30523,
18
+ "transformer_engine": false,
19
+ "use_fp8": false
20
+ }
configuration_embed1.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Configuration class for Cosmos-Embed1."""
17
+
18
+ from typing import Any, Literal, Tuple, Union
19
+
20
+ from transformers import AutoConfig, PretrainedConfig
21
+
22
+
23
class CosmosEmbed1Config(PretrainedConfig):
    """HuggingFace configuration for the Cosmos-Embed1 video-text embedding model.

    Carries the hyper-parameters needed to build the model: embedding width,
    learnable query-token count, text/video input sizes, and the
    TransformerEngine/FP8 acceleration switches.
    """

    # Must match the "model_type" field in config.json; registered with
    # AutoConfig at the bottom of this module.
    model_type = "cosmos-embed1"

    def __init__(
        self,
        embed_dim: int = 768,
        num_query_tokens: int = 32,
        max_txt_len: int = 128,
        num_video_frames: int = 8,
        temporal_encoding_type: Literal[
            "neighboring_token_propagation", "temporal_parameter"
        ] = "neighboring_token_propagation",
        resolution: Union[int, Tuple[int, int]] = 224,
        vocab_size: int = 30523,
        transformer_engine: bool = False,
        use_fp8: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a Cosmos-Embed1 configuration.

        Args:
            embed_dim (int): dimension of the extracted text-visual embeddings.
            num_query_tokens (int): number of learnable query tokens.
            max_txt_len (int): maximum text-token sequence length before truncation.
            num_video_frames (int): number of input video frames.
            temporal_encoding_type (str): temporal encoding module type.
            resolution (Union[int, Tuple[int, int]]): input video frame resolution.
                An integer means a square frame (height == width); a tuple is
                interpreted as (height, width).
            vocab_size (int): text-tokenizer vocabulary size. The default is
                `bert-base-uncased` with one extra [DEC] token.
            transformer_engine (bool): whether to use TransformerEngine for
                acceleration.
            use_fp8 (bool): whether to use FP8 precision
                (requires transformer_engine=True).
            **kwargs: forwarded unchanged to
                :class:`transformers.PretrainedConfig`.
        """
        super().__init__(**kwargs)

        # Plain attribute storage: PretrainedConfig serializes these
        # verbatim into config.json.
        self.embed_dim = embed_dim
        self.num_query_tokens = num_query_tokens
        self.max_txt_len = max_txt_len
        self.num_video_frames = num_video_frames
        self.temporal_encoding_type = temporal_encoding_type
        self.resolution = resolution
        self.vocab_size = vocab_size
        self.transformer_engine = transformer_engine
        self.use_fp8 = use_fp8


# Expose the config to AutoConfig.from_pretrained (used together with the
# "auto_map" entry in config.json for trust_remote_code loading).
AutoConfig.register("cosmos-embed1", CosmosEmbed1Config)
export_config.yaml ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb:
2
+ enable: false
3
+ project: cosmos_embed1
4
+ group: ''
5
+ name: ''
6
+ tags: []
7
+ save_code: false
8
+ api_key: ''
9
+ model:
10
+ network:
11
+ visual_encoder:
12
+ type: eva_vit_g
13
+ img_size: 224
14
+ pretrained: false
15
+ use_fp8: false
16
+ transformer_engine: false
17
+ checkpoint_activations: false
18
+ checkpoint_attention: false
19
+ embed_dim: 768
20
+ num_query_tokens: 32
21
+ max_txt_len: 128
22
+ num_video_frames: 8
23
+ spatial_resolution:
24
+ - 448
25
+ - 448
26
+ temporal_encoding_type: neighboring_token_propagation
27
+ contrastive_type: clip
28
+ qformer_pretrain_ckpt: null
29
+ query_pooling_type: avg
30
+ pretrained_text_encoder: false
31
+ pretrained_visual_encoder: false
32
+ num_heldout_frames: 0
33
+ pretrained_model_path: null
34
+ pretrained_model_strict: true
35
+ precision: fp32
36
+ input_hw:
37
+ - 512
38
+ - 512
39
+ fsdp:
40
+ enabled: false
41
+ shard_size: null
42
+ replica_size: null
43
+ fsdp_shard_size: 8
44
+ dataset:
45
+ train_dataset:
46
+ dataset_type: mock
47
+ metadata: null
48
+ data_root: null
49
+ num_video_frames: 8
50
+ resolution:
51
+ - 224
52
+ - 224
53
+ batch_size: 4
54
+ workers: 4
55
+ drop_last: true
56
+ prefetch_factor: 2
57
+ pin_memory: true
58
+ split: null
59
+ random_caption: false
60
+ path_prefix_mapping: {}
61
+ skip_missing_files: true
62
+ caption_field: anomaly_type
63
+ mp4_urls: null
64
+ caption_to_label: {}
65
+ chunk_size_sec: 5.0
66
+ val_dataset:
67
+ dataset_type: mock
68
+ metadata: null
69
+ data_root: null
70
+ num_video_frames: 8
71
+ resolution:
72
+ - 224
73
+ - 224
74
+ batch_size: 4
75
+ workers: 4
76
+ drop_last: true
77
+ prefetch_factor: 2
78
+ pin_memory: true
79
+ split: null
80
+ random_caption: false
81
+ path_prefix_mapping: {}
82
+ skip_missing_files: true
83
+ caption_field: anomaly_type
84
+ mp4_urls: null
85
+ caption_to_label: {}
86
+ chunk_size_sec: 5.0
87
+ test_dataset:
88
+ dataset_type: mock
89
+ metadata: null
90
+ data_root: null
91
+ num_video_frames: 8
92
+ resolution:
93
+ - 224
94
+ - 224
95
+ batch_size: 4
96
+ workers: 4
97
+ drop_last: true
98
+ prefetch_factor: 2
99
+ pin_memory: true
100
+ split: null
101
+ random_caption: false
102
+ path_prefix_mapping: {}
103
+ skip_missing_files: true
104
+ caption_field: anomaly_type
105
+ mp4_urls: null
106
+ caption_to_label: {}
107
+ chunk_size_sec: 5.0
108
+ inference_dataset:
109
+ dataset_type: mock
110
+ metadata: null
111
+ data_root: null
112
+ num_video_frames: 8
113
+ resolution:
114
+ - 224
115
+ - 224
116
+ batch_size: 4
117
+ workers: 4
118
+ drop_last: true
119
+ prefetch_factor: 2
120
+ pin_memory: true
121
+ split: null
122
+ random_caption: false
123
+ path_prefix_mapping: {}
124
+ skip_missing_files: true
125
+ caption_field: anomaly_type
126
+ mp4_urls: null
127
+ caption_to_label: {}
128
+ chunk_size_sec: 5.0
129
+ train:
130
+ optim:
131
+ optim: adamw
132
+ lr: 1.0e-05
133
+ weight_decay: 1.0e-05
134
+ betas:
135
+ - 0.9
136
+ - 0.98
137
+ warmup_steps: 1000
138
+ policy: cosine
139
+ lr_decay_iters: 50000
140
+ loss_weights:
141
+ contrastive_loss: 1.0
142
+ captioning_loss: 1.0
143
+ matching_loss: 1.0
144
+ lora:
145
+ enabled: false
146
+ lora_rank: 8
147
+ lora_alpha: 16
148
+ lora_dropout: 0.1
149
+ bias: none
150
+ use_rslora: false
151
+ use_dora: false
152
+ target_modules: []
153
+ modules_to_save: []
154
+ seed: 1234
155
+ max_iter: 50000
156
+ num_nodes: 1
157
+ num_gpus: 1
158
+ gpu_ids:
159
+ - 0
160
+ validation_iter: 1000
161
+ checkpoint_iter: 1000
162
+ clip_grad_norm: 0.0
163
+ precision: bf16
164
+ resume_training_checkpoint_path: null
165
+ callbacks:
166
+ wandb: {}
167
+ clamp_logit_scale: {}
168
+ logit_parameters_monitor: {}
169
+ iter_speed:
170
+ every_n: 50
171
+ save_s3: false
172
+ gradient_clip:
173
+ clip_norm: 3.0
174
+ grad_norm_monitor:
175
+ every_n: 500
176
+ verbose: false
177
+ spectral_norm_monitor:
178
+ every_n: 1000
179
+ verbose: true
180
+ ema: {}
181
+ log_losses:
182
+ every_n: 50
183
+ verbose: true
184
+ text_frames_visualizer:
185
+ every_n: 500
186
+ pca_feature_map_visualizer:
187
+ every_n: 500
188
+ max_val_iter: null
189
+ freeze_visual_encoder: true
190
+ use_captioning_loss: true
191
+ use_text_matching_loss: false
192
+ ema:
193
+ enabled: false
194
+ beta: 0.9999
195
+ spectral_reparam: false
196
+ damp:
197
+ enabled: false
198
+ beta: 0.1
199
+ mode: const
200
+ load_training_state: false
201
+ strict_resume: false
202
+ evaluate:
203
+ checkpoint: null
204
+ max_val_batches: -1
205
+ num_gpus: 1
206
+ callbacks:
207
+ topk_classification: true
208
+ embedding_visualization: false
209
+ top_k_values:
210
+ - 1
211
+ - 3
212
+ - 5
213
+ - 10
214
+ max_eval_samples: 2000
215
+ load_dataset_pkl: null
216
+ save_dataset_pkl: null
217
+ inference:
218
+ checkpoint: null
219
+ query:
220
+ input_videos: []
221
+ input_texts: []
222
+ num_gpus: 1
223
+ k: 5
224
+ load_dataset_pkl: null
225
+ save_dataset_pkl: null
226
+ export:
227
+ checkpoint: /workspace/alicli/experiments/cosmos_embed1_finetune/finetune/finetune_448p_hf/train/checkpoints/iter_000006000.pt
228
+ onnx_file: null
229
+ mode: huggingface
230
+ opset_version: 17
231
+ batch_size: 1
232
+ on_cpu: true
233
+ verbose: false
234
+ simplify: false
235
+ hf_output_dir: /workspace/alicli/experiments/cosmos_embed1_finetune/finetune/finetune_448p_hf/cosmos_embed1_448p_6000iter_hf
236
+ results_dir: /workspace/alicli/experiments/cosmos_embed1_finetune/finetune/finetune_448p_hf
237
+ encryption_key: null
238
+ model_name: cosmos_embed1
model-00001-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe3ea01cd802f56b560fc34b1a52f417821eb534d694f2e36c8c4f4385cc12cd
3
+ size 517488448
model-00002-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8f7f924243c33217a764f817c5d27ba8ffff5f9c15a2e42abd788468b3eb2cf
3
+ size 493842736
model-00003-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c1e0b262e862efac6c2aa083c25a897363a2bd116aefa44dc259b810513878
3
+ size 505010128
model-00004-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6bb2c6eb056cd177ca4c4e4b7b5fe2b9c4e47b61659c13da33f03abfe522851
3
+ size 505010136
model-00005-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:277dd3fbe9b5f8197c117244edff18bc220a6762294d6d5e9bfe5dd2380e0dc2
3
+ size 505010120
model-00006-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98e45d9d21e0fc0af325908c5654825f9d3aa5384e1754a5bdee1c3ff4b17c0b
3
+ size 505010136
model-00007-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7d65fc3291c03329047b0b6214353fb107a079dcef53c3c7d05ef39ba19b663
3
+ size 505010120
model-00008-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dcf217ab3159b2ec92465e0d266d3bd18fff55c5dd2a6264edd6f13ab165fd1
3
+ size 505010136
model-00009-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e028831d5bef332882391436d7ebb104fcb95dcdf75f9d56b47e96623c10b144
3
+ size 505010088
model-00010-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15ef5ea4966b056b4a5070307ff37ff438abad5f869c77a688f0660401a942c1
3
+ size 245726200
model.safetensors.index.json ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 4792032744
4
+ },
5
+ "weight_map": {
6
+ "itm_proj.bias": "model-00001-of-00010.safetensors",
7
+ "itm_proj.weight": "model-00001-of-00010.safetensors",
8
+ "ln_vision.bias": "model-00001-of-00010.safetensors",
9
+ "ln_vision.weight": "model-00001-of-00010.safetensors",
10
+ "logit_bias": "model-00001-of-00010.safetensors",
11
+ "logit_scale": "model-00001-of-00010.safetensors",
12
+ "qformer.bert.embeddings.LayerNorm.bias": "model-00001-of-00010.safetensors",
13
+ "qformer.bert.embeddings.LayerNorm.weight": "model-00001-of-00010.safetensors",
14
+ "qformer.bert.embeddings.position_embeddings.weight": "model-00001-of-00010.safetensors",
15
+ "qformer.bert.embeddings.position_ids": "model-00001-of-00010.safetensors",
16
+ "qformer.bert.embeddings.word_embeddings.weight": "model-00001-of-00010.safetensors",
17
+ "qformer.bert.encoder.layer.0.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
18
+ "qformer.bert.encoder.layer.0.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
19
+ "qformer.bert.encoder.layer.0.attention.output.dense.bias": "model-00001-of-00010.safetensors",
20
+ "qformer.bert.encoder.layer.0.attention.output.dense.weight": "model-00001-of-00010.safetensors",
21
+ "qformer.bert.encoder.layer.0.attention.self.key.bias": "model-00001-of-00010.safetensors",
22
+ "qformer.bert.encoder.layer.0.attention.self.key.weight": "model-00001-of-00010.safetensors",
23
+ "qformer.bert.encoder.layer.0.attention.self.query.bias": "model-00001-of-00010.safetensors",
24
+ "qformer.bert.encoder.layer.0.attention.self.query.weight": "model-00001-of-00010.safetensors",
25
+ "qformer.bert.encoder.layer.0.attention.self.value.bias": "model-00001-of-00010.safetensors",
26
+ "qformer.bert.encoder.layer.0.attention.self.value.weight": "model-00001-of-00010.safetensors",
27
+ "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
28
+ "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
29
+ "qformer.bert.encoder.layer.0.crossattention.output.dense.bias": "model-00001-of-00010.safetensors",
30
+ "qformer.bert.encoder.layer.0.crossattention.output.dense.weight": "model-00001-of-00010.safetensors",
31
+ "qformer.bert.encoder.layer.0.crossattention.self.key.bias": "model-00001-of-00010.safetensors",
32
+ "qformer.bert.encoder.layer.0.crossattention.self.key.weight": "model-00001-of-00010.safetensors",
33
+ "qformer.bert.encoder.layer.0.crossattention.self.query.bias": "model-00001-of-00010.safetensors",
34
+ "qformer.bert.encoder.layer.0.crossattention.self.query.weight": "model-00001-of-00010.safetensors",
35
+ "qformer.bert.encoder.layer.0.crossattention.self.value.bias": "model-00001-of-00010.safetensors",
36
+ "qformer.bert.encoder.layer.0.crossattention.self.value.weight": "model-00001-of-00010.safetensors",
37
+ "qformer.bert.encoder.layer.0.intermediate.dense.bias": "model-00001-of-00010.safetensors",
38
+ "qformer.bert.encoder.layer.0.intermediate.dense.weight": "model-00001-of-00010.safetensors",
39
+ "qformer.bert.encoder.layer.0.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
40
+ "qformer.bert.encoder.layer.0.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
41
+ "qformer.bert.encoder.layer.0.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
42
+ "qformer.bert.encoder.layer.0.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
43
+ "qformer.bert.encoder.layer.0.output.dense.bias": "model-00001-of-00010.safetensors",
44
+ "qformer.bert.encoder.layer.0.output.dense.weight": "model-00001-of-00010.safetensors",
45
+ "qformer.bert.encoder.layer.0.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
46
+ "qformer.bert.encoder.layer.0.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
47
+ "qformer.bert.encoder.layer.0.output_query.dense.bias": "model-00001-of-00010.safetensors",
48
+ "qformer.bert.encoder.layer.0.output_query.dense.weight": "model-00001-of-00010.safetensors",
49
+ "qformer.bert.encoder.layer.1.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
50
+ "qformer.bert.encoder.layer.1.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
51
+ "qformer.bert.encoder.layer.1.attention.output.dense.bias": "model-00001-of-00010.safetensors",
52
+ "qformer.bert.encoder.layer.1.attention.output.dense.weight": "model-00001-of-00010.safetensors",
53
+ "qformer.bert.encoder.layer.1.attention.self.key.bias": "model-00001-of-00010.safetensors",
54
+ "qformer.bert.encoder.layer.1.attention.self.key.weight": "model-00001-of-00010.safetensors",
55
+ "qformer.bert.encoder.layer.1.attention.self.query.bias": "model-00001-of-00010.safetensors",
56
+ "qformer.bert.encoder.layer.1.attention.self.query.weight": "model-00001-of-00010.safetensors",
57
+ "qformer.bert.encoder.layer.1.attention.self.value.bias": "model-00001-of-00010.safetensors",
58
+ "qformer.bert.encoder.layer.1.attention.self.value.weight": "model-00001-of-00010.safetensors",
59
+ "qformer.bert.encoder.layer.1.intermediate.dense.bias": "model-00001-of-00010.safetensors",
60
+ "qformer.bert.encoder.layer.1.intermediate.dense.weight": "model-00001-of-00010.safetensors",
61
+ "qformer.bert.encoder.layer.1.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
62
+ "qformer.bert.encoder.layer.1.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
63
+ "qformer.bert.encoder.layer.1.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
64
+ "qformer.bert.encoder.layer.1.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
65
+ "qformer.bert.encoder.layer.1.output.dense.bias": "model-00001-of-00010.safetensors",
66
+ "qformer.bert.encoder.layer.1.output.dense.weight": "model-00001-of-00010.safetensors",
67
+ "qformer.bert.encoder.layer.1.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
68
+ "qformer.bert.encoder.layer.1.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
69
+ "qformer.bert.encoder.layer.1.output_query.dense.bias": "model-00001-of-00010.safetensors",
70
+ "qformer.bert.encoder.layer.1.output_query.dense.weight": "model-00001-of-00010.safetensors",
71
+ "qformer.bert.encoder.layer.10.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
72
+ "qformer.bert.encoder.layer.10.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
73
+ "qformer.bert.encoder.layer.10.attention.output.dense.bias": "model-00001-of-00010.safetensors",
74
+ "qformer.bert.encoder.layer.10.attention.output.dense.weight": "model-00001-of-00010.safetensors",
75
+ "qformer.bert.encoder.layer.10.attention.self.key.bias": "model-00001-of-00010.safetensors",
76
+ "qformer.bert.encoder.layer.10.attention.self.key.weight": "model-00001-of-00010.safetensors",
77
+ "qformer.bert.encoder.layer.10.attention.self.query.bias": "model-00001-of-00010.safetensors",
78
+ "qformer.bert.encoder.layer.10.attention.self.query.weight": "model-00001-of-00010.safetensors",
79
+ "qformer.bert.encoder.layer.10.attention.self.value.bias": "model-00001-of-00010.safetensors",
80
+ "qformer.bert.encoder.layer.10.attention.self.value.weight": "model-00001-of-00010.safetensors",
81
+ "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
82
+ "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
83
+ "qformer.bert.encoder.layer.10.crossattention.output.dense.bias": "model-00001-of-00010.safetensors",
84
+ "qformer.bert.encoder.layer.10.crossattention.output.dense.weight": "model-00001-of-00010.safetensors",
85
+ "qformer.bert.encoder.layer.10.crossattention.self.key.bias": "model-00001-of-00010.safetensors",
86
+ "qformer.bert.encoder.layer.10.crossattention.self.key.weight": "model-00001-of-00010.safetensors",
87
+ "qformer.bert.encoder.layer.10.crossattention.self.query.bias": "model-00001-of-00010.safetensors",
88
+ "qformer.bert.encoder.layer.10.crossattention.self.query.weight": "model-00001-of-00010.safetensors",
89
+ "qformer.bert.encoder.layer.10.crossattention.self.value.bias": "model-00001-of-00010.safetensors",
90
+ "qformer.bert.encoder.layer.10.crossattention.self.value.weight": "model-00001-of-00010.safetensors",
91
+ "qformer.bert.encoder.layer.10.intermediate.dense.bias": "model-00001-of-00010.safetensors",
92
+ "qformer.bert.encoder.layer.10.intermediate.dense.weight": "model-00001-of-00010.safetensors",
93
+ "qformer.bert.encoder.layer.10.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
94
+ "qformer.bert.encoder.layer.10.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
95
+ "qformer.bert.encoder.layer.10.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
96
+ "qformer.bert.encoder.layer.10.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
97
+ "qformer.bert.encoder.layer.10.output.dense.bias": "model-00001-of-00010.safetensors",
98
+ "qformer.bert.encoder.layer.10.output.dense.weight": "model-00001-of-00010.safetensors",
99
+ "qformer.bert.encoder.layer.10.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
100
+ "qformer.bert.encoder.layer.10.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
101
+ "qformer.bert.encoder.layer.10.output_query.dense.bias": "model-00001-of-00010.safetensors",
102
+ "qformer.bert.encoder.layer.10.output_query.dense.weight": "model-00001-of-00010.safetensors",
103
+ "qformer.bert.encoder.layer.11.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
104
+ "qformer.bert.encoder.layer.11.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
105
+ "qformer.bert.encoder.layer.11.attention.output.dense.bias": "model-00001-of-00010.safetensors",
106
+ "qformer.bert.encoder.layer.11.attention.output.dense.weight": "model-00001-of-00010.safetensors",
107
+ "qformer.bert.encoder.layer.11.attention.self.key.bias": "model-00001-of-00010.safetensors",
108
+ "qformer.bert.encoder.layer.11.attention.self.key.weight": "model-00001-of-00010.safetensors",
109
+ "qformer.bert.encoder.layer.11.attention.self.query.bias": "model-00001-of-00010.safetensors",
110
+ "qformer.bert.encoder.layer.11.attention.self.query.weight": "model-00001-of-00010.safetensors",
111
+ "qformer.bert.encoder.layer.11.attention.self.value.bias": "model-00001-of-00010.safetensors",
112
+ "qformer.bert.encoder.layer.11.attention.self.value.weight": "model-00001-of-00010.safetensors",
113
+ "qformer.bert.encoder.layer.11.intermediate.dense.bias": "model-00001-of-00010.safetensors",
114
+ "qformer.bert.encoder.layer.11.intermediate.dense.weight": "model-00001-of-00010.safetensors",
115
+ "qformer.bert.encoder.layer.11.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
116
+ "qformer.bert.encoder.layer.11.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
117
+ "qformer.bert.encoder.layer.11.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
118
+ "qformer.bert.encoder.layer.11.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
119
+ "qformer.bert.encoder.layer.11.output.dense.bias": "model-00001-of-00010.safetensors",
120
+ "qformer.bert.encoder.layer.11.output.dense.weight": "model-00001-of-00010.safetensors",
121
+ "qformer.bert.encoder.layer.11.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
122
+ "qformer.bert.encoder.layer.11.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
123
+ "qformer.bert.encoder.layer.11.output_query.dense.bias": "model-00001-of-00010.safetensors",
124
+ "qformer.bert.encoder.layer.11.output_query.dense.weight": "model-00001-of-00010.safetensors",
125
+ "qformer.bert.encoder.layer.2.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
126
+ "qformer.bert.encoder.layer.2.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
127
+ "qformer.bert.encoder.layer.2.attention.output.dense.bias": "model-00001-of-00010.safetensors",
128
+ "qformer.bert.encoder.layer.2.attention.output.dense.weight": "model-00001-of-00010.safetensors",
129
+ "qformer.bert.encoder.layer.2.attention.self.key.bias": "model-00001-of-00010.safetensors",
130
+ "qformer.bert.encoder.layer.2.attention.self.key.weight": "model-00001-of-00010.safetensors",
131
+ "qformer.bert.encoder.layer.2.attention.self.query.bias": "model-00001-of-00010.safetensors",
132
+ "qformer.bert.encoder.layer.2.attention.self.query.weight": "model-00001-of-00010.safetensors",
133
+ "qformer.bert.encoder.layer.2.attention.self.value.bias": "model-00001-of-00010.safetensors",
134
+ "qformer.bert.encoder.layer.2.attention.self.value.weight": "model-00001-of-00010.safetensors",
135
+ "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
136
+ "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
137
+ "qformer.bert.encoder.layer.2.crossattention.output.dense.bias": "model-00001-of-00010.safetensors",
138
+ "qformer.bert.encoder.layer.2.crossattention.output.dense.weight": "model-00001-of-00010.safetensors",
139
+ "qformer.bert.encoder.layer.2.crossattention.self.key.bias": "model-00001-of-00010.safetensors",
140
+ "qformer.bert.encoder.layer.2.crossattention.self.key.weight": "model-00001-of-00010.safetensors",
141
+ "qformer.bert.encoder.layer.2.crossattention.self.query.bias": "model-00001-of-00010.safetensors",
142
+ "qformer.bert.encoder.layer.2.crossattention.self.query.weight": "model-00001-of-00010.safetensors",
143
+ "qformer.bert.encoder.layer.2.crossattention.self.value.bias": "model-00001-of-00010.safetensors",
144
+ "qformer.bert.encoder.layer.2.crossattention.self.value.weight": "model-00001-of-00010.safetensors",
145
+ "qformer.bert.encoder.layer.2.intermediate.dense.bias": "model-00001-of-00010.safetensors",
146
+ "qformer.bert.encoder.layer.2.intermediate.dense.weight": "model-00001-of-00010.safetensors",
147
+ "qformer.bert.encoder.layer.2.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
148
+ "qformer.bert.encoder.layer.2.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
149
+ "qformer.bert.encoder.layer.2.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
150
+ "qformer.bert.encoder.layer.2.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
151
+ "qformer.bert.encoder.layer.2.output.dense.bias": "model-00001-of-00010.safetensors",
152
+ "qformer.bert.encoder.layer.2.output.dense.weight": "model-00001-of-00010.safetensors",
153
+ "qformer.bert.encoder.layer.2.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
154
+ "qformer.bert.encoder.layer.2.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
155
+ "qformer.bert.encoder.layer.2.output_query.dense.bias": "model-00001-of-00010.safetensors",
156
+ "qformer.bert.encoder.layer.2.output_query.dense.weight": "model-00001-of-00010.safetensors",
157
+ "qformer.bert.encoder.layer.3.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
158
+ "qformer.bert.encoder.layer.3.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
159
+ "qformer.bert.encoder.layer.3.attention.output.dense.bias": "model-00001-of-00010.safetensors",
160
+ "qformer.bert.encoder.layer.3.attention.output.dense.weight": "model-00001-of-00010.safetensors",
161
+ "qformer.bert.encoder.layer.3.attention.self.key.bias": "model-00001-of-00010.safetensors",
162
+ "qformer.bert.encoder.layer.3.attention.self.key.weight": "model-00001-of-00010.safetensors",
163
+ "qformer.bert.encoder.layer.3.attention.self.query.bias": "model-00001-of-00010.safetensors",
164
+ "qformer.bert.encoder.layer.3.attention.self.query.weight": "model-00001-of-00010.safetensors",
165
+ "qformer.bert.encoder.layer.3.attention.self.value.bias": "model-00001-of-00010.safetensors",
166
+ "qformer.bert.encoder.layer.3.attention.self.value.weight": "model-00001-of-00010.safetensors",
167
+ "qformer.bert.encoder.layer.3.intermediate.dense.bias": "model-00001-of-00010.safetensors",
168
+ "qformer.bert.encoder.layer.3.intermediate.dense.weight": "model-00001-of-00010.safetensors",
169
+ "qformer.bert.encoder.layer.3.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
170
+ "qformer.bert.encoder.layer.3.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
171
+ "qformer.bert.encoder.layer.3.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
172
+ "qformer.bert.encoder.layer.3.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
173
+ "qformer.bert.encoder.layer.3.output.dense.bias": "model-00001-of-00010.safetensors",
174
+ "qformer.bert.encoder.layer.3.output.dense.weight": "model-00001-of-00010.safetensors",
175
+ "qformer.bert.encoder.layer.3.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
176
+ "qformer.bert.encoder.layer.3.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
177
+ "qformer.bert.encoder.layer.3.output_query.dense.bias": "model-00001-of-00010.safetensors",
178
+ "qformer.bert.encoder.layer.3.output_query.dense.weight": "model-00001-of-00010.safetensors",
179
+ "qformer.bert.encoder.layer.4.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
180
+ "qformer.bert.encoder.layer.4.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
181
+ "qformer.bert.encoder.layer.4.attention.output.dense.bias": "model-00001-of-00010.safetensors",
182
+ "qformer.bert.encoder.layer.4.attention.output.dense.weight": "model-00001-of-00010.safetensors",
183
+ "qformer.bert.encoder.layer.4.attention.self.key.bias": "model-00001-of-00010.safetensors",
184
+ "qformer.bert.encoder.layer.4.attention.self.key.weight": "model-00001-of-00010.safetensors",
185
+ "qformer.bert.encoder.layer.4.attention.self.query.bias": "model-00001-of-00010.safetensors",
186
+ "qformer.bert.encoder.layer.4.attention.self.query.weight": "model-00001-of-00010.safetensors",
187
+ "qformer.bert.encoder.layer.4.attention.self.value.bias": "model-00001-of-00010.safetensors",
188
+ "qformer.bert.encoder.layer.4.attention.self.value.weight": "model-00001-of-00010.safetensors",
189
+ "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
190
+ "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
191
+ "qformer.bert.encoder.layer.4.crossattention.output.dense.bias": "model-00001-of-00010.safetensors",
192
+ "qformer.bert.encoder.layer.4.crossattention.output.dense.weight": "model-00001-of-00010.safetensors",
193
+ "qformer.bert.encoder.layer.4.crossattention.self.key.bias": "model-00001-of-00010.safetensors",
194
+ "qformer.bert.encoder.layer.4.crossattention.self.key.weight": "model-00001-of-00010.safetensors",
195
+ "qformer.bert.encoder.layer.4.crossattention.self.query.bias": "model-00001-of-00010.safetensors",
196
+ "qformer.bert.encoder.layer.4.crossattention.self.query.weight": "model-00001-of-00010.safetensors",
197
+ "qformer.bert.encoder.layer.4.crossattention.self.value.bias": "model-00001-of-00010.safetensors",
198
+ "qformer.bert.encoder.layer.4.crossattention.self.value.weight": "model-00001-of-00010.safetensors",
199
+ "qformer.bert.encoder.layer.4.intermediate.dense.bias": "model-00001-of-00010.safetensors",
200
+ "qformer.bert.encoder.layer.4.intermediate.dense.weight": "model-00001-of-00010.safetensors",
201
+ "qformer.bert.encoder.layer.4.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
202
+ "qformer.bert.encoder.layer.4.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
203
+ "qformer.bert.encoder.layer.4.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
204
+ "qformer.bert.encoder.layer.4.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
205
+ "qformer.bert.encoder.layer.4.output.dense.bias": "model-00001-of-00010.safetensors",
206
+ "qformer.bert.encoder.layer.4.output.dense.weight": "model-00001-of-00010.safetensors",
207
+ "qformer.bert.encoder.layer.4.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
208
+ "qformer.bert.encoder.layer.4.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
209
+ "qformer.bert.encoder.layer.4.output_query.dense.bias": "model-00001-of-00010.safetensors",
210
+ "qformer.bert.encoder.layer.4.output_query.dense.weight": "model-00001-of-00010.safetensors",
211
+ "qformer.bert.encoder.layer.5.attention.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
212
+ "qformer.bert.encoder.layer.5.attention.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
213
+ "qformer.bert.encoder.layer.5.attention.output.dense.bias": "model-00001-of-00010.safetensors",
214
+ "qformer.bert.encoder.layer.5.attention.output.dense.weight": "model-00001-of-00010.safetensors",
215
+ "qformer.bert.encoder.layer.5.attention.self.key.bias": "model-00001-of-00010.safetensors",
216
+ "qformer.bert.encoder.layer.5.attention.self.key.weight": "model-00001-of-00010.safetensors",
217
+ "qformer.bert.encoder.layer.5.attention.self.query.bias": "model-00001-of-00010.safetensors",
218
+ "qformer.bert.encoder.layer.5.attention.self.query.weight": "model-00001-of-00010.safetensors",
219
+ "qformer.bert.encoder.layer.5.attention.self.value.bias": "model-00001-of-00010.safetensors",
220
+ "qformer.bert.encoder.layer.5.attention.self.value.weight": "model-00001-of-00010.safetensors",
221
+ "qformer.bert.encoder.layer.5.intermediate.dense.bias": "model-00001-of-00010.safetensors",
222
+ "qformer.bert.encoder.layer.5.intermediate.dense.weight": "model-00001-of-00010.safetensors",
223
+ "qformer.bert.encoder.layer.5.intermediate_query.dense.bias": "model-00001-of-00010.safetensors",
224
+ "qformer.bert.encoder.layer.5.intermediate_query.dense.weight": "model-00001-of-00010.safetensors",
225
+ "qformer.bert.encoder.layer.5.output.LayerNorm.bias": "model-00001-of-00010.safetensors",
226
+ "qformer.bert.encoder.layer.5.output.LayerNorm.weight": "model-00001-of-00010.safetensors",
227
+ "qformer.bert.encoder.layer.5.output.dense.bias": "model-00001-of-00010.safetensors",
228
+ "qformer.bert.encoder.layer.5.output.dense.weight": "model-00001-of-00010.safetensors",
229
+ "qformer.bert.encoder.layer.5.output_query.LayerNorm.bias": "model-00001-of-00010.safetensors",
230
+ "qformer.bert.encoder.layer.5.output_query.LayerNorm.weight": "model-00001-of-00010.safetensors",
231
+ "qformer.bert.encoder.layer.5.output_query.dense.bias": "model-00001-of-00010.safetensors",
232
+ "qformer.bert.encoder.layer.5.output_query.dense.weight": "model-00002-of-00010.safetensors",
233
+ "qformer.bert.encoder.layer.6.attention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
234
+ "qformer.bert.encoder.layer.6.attention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
235
+ "qformer.bert.encoder.layer.6.attention.output.dense.bias": "model-00002-of-00010.safetensors",
236
+ "qformer.bert.encoder.layer.6.attention.output.dense.weight": "model-00002-of-00010.safetensors",
237
+ "qformer.bert.encoder.layer.6.attention.self.key.bias": "model-00002-of-00010.safetensors",
238
+ "qformer.bert.encoder.layer.6.attention.self.key.weight": "model-00002-of-00010.safetensors",
239
+ "qformer.bert.encoder.layer.6.attention.self.query.bias": "model-00002-of-00010.safetensors",
240
+ "qformer.bert.encoder.layer.6.attention.self.query.weight": "model-00002-of-00010.safetensors",
241
+ "qformer.bert.encoder.layer.6.attention.self.value.bias": "model-00002-of-00010.safetensors",
242
+ "qformer.bert.encoder.layer.6.attention.self.value.weight": "model-00002-of-00010.safetensors",
243
+ "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
244
+ "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
245
+ "qformer.bert.encoder.layer.6.crossattention.output.dense.bias": "model-00002-of-00010.safetensors",
246
+ "qformer.bert.encoder.layer.6.crossattention.output.dense.weight": "model-00002-of-00010.safetensors",
247
+ "qformer.bert.encoder.layer.6.crossattention.self.key.bias": "model-00002-of-00010.safetensors",
248
+ "qformer.bert.encoder.layer.6.crossattention.self.key.weight": "model-00002-of-00010.safetensors",
249
+ "qformer.bert.encoder.layer.6.crossattention.self.query.bias": "model-00002-of-00010.safetensors",
250
+ "qformer.bert.encoder.layer.6.crossattention.self.query.weight": "model-00002-of-00010.safetensors",
251
+ "qformer.bert.encoder.layer.6.crossattention.self.value.bias": "model-00002-of-00010.safetensors",
252
+ "qformer.bert.encoder.layer.6.crossattention.self.value.weight": "model-00002-of-00010.safetensors",
253
+ "qformer.bert.encoder.layer.6.intermediate.dense.bias": "model-00002-of-00010.safetensors",
254
+ "qformer.bert.encoder.layer.6.intermediate.dense.weight": "model-00002-of-00010.safetensors",
255
+ "qformer.bert.encoder.layer.6.intermediate_query.dense.bias": "model-00002-of-00010.safetensors",
256
+ "qformer.bert.encoder.layer.6.intermediate_query.dense.weight": "model-00002-of-00010.safetensors",
257
+ "qformer.bert.encoder.layer.6.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
258
+ "qformer.bert.encoder.layer.6.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
259
+ "qformer.bert.encoder.layer.6.output.dense.bias": "model-00002-of-00010.safetensors",
260
+ "qformer.bert.encoder.layer.6.output.dense.weight": "model-00002-of-00010.safetensors",
261
+ "qformer.bert.encoder.layer.6.output_query.LayerNorm.bias": "model-00002-of-00010.safetensors",
262
+ "qformer.bert.encoder.layer.6.output_query.LayerNorm.weight": "model-00002-of-00010.safetensors",
263
+ "qformer.bert.encoder.layer.6.output_query.dense.bias": "model-00002-of-00010.safetensors",
264
+ "qformer.bert.encoder.layer.6.output_query.dense.weight": "model-00002-of-00010.safetensors",
265
+ "qformer.bert.encoder.layer.7.attention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
266
+ "qformer.bert.encoder.layer.7.attention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
267
+ "qformer.bert.encoder.layer.7.attention.output.dense.bias": "model-00002-of-00010.safetensors",
268
+ "qformer.bert.encoder.layer.7.attention.output.dense.weight": "model-00002-of-00010.safetensors",
269
+ "qformer.bert.encoder.layer.7.attention.self.key.bias": "model-00002-of-00010.safetensors",
270
+ "qformer.bert.encoder.layer.7.attention.self.key.weight": "model-00002-of-00010.safetensors",
271
+ "qformer.bert.encoder.layer.7.attention.self.query.bias": "model-00002-of-00010.safetensors",
272
+ "qformer.bert.encoder.layer.7.attention.self.query.weight": "model-00002-of-00010.safetensors",
273
+ "qformer.bert.encoder.layer.7.attention.self.value.bias": "model-00002-of-00010.safetensors",
274
+ "qformer.bert.encoder.layer.7.attention.self.value.weight": "model-00002-of-00010.safetensors",
275
+ "qformer.bert.encoder.layer.7.intermediate.dense.bias": "model-00002-of-00010.safetensors",
276
+ "qformer.bert.encoder.layer.7.intermediate.dense.weight": "model-00002-of-00010.safetensors",
277
+ "qformer.bert.encoder.layer.7.intermediate_query.dense.bias": "model-00002-of-00010.safetensors",
278
+ "qformer.bert.encoder.layer.7.intermediate_query.dense.weight": "model-00002-of-00010.safetensors",
279
+ "qformer.bert.encoder.layer.7.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
280
+ "qformer.bert.encoder.layer.7.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
281
+ "qformer.bert.encoder.layer.7.output.dense.bias": "model-00002-of-00010.safetensors",
282
+ "qformer.bert.encoder.layer.7.output.dense.weight": "model-00002-of-00010.safetensors",
283
+ "qformer.bert.encoder.layer.7.output_query.LayerNorm.bias": "model-00002-of-00010.safetensors",
284
+ "qformer.bert.encoder.layer.7.output_query.LayerNorm.weight": "model-00002-of-00010.safetensors",
285
+ "qformer.bert.encoder.layer.7.output_query.dense.bias": "model-00002-of-00010.safetensors",
286
+ "qformer.bert.encoder.layer.7.output_query.dense.weight": "model-00002-of-00010.safetensors",
287
+ "qformer.bert.encoder.layer.8.attention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
288
+ "qformer.bert.encoder.layer.8.attention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
289
+ "qformer.bert.encoder.layer.8.attention.output.dense.bias": "model-00002-of-00010.safetensors",
290
+ "qformer.bert.encoder.layer.8.attention.output.dense.weight": "model-00002-of-00010.safetensors",
291
+ "qformer.bert.encoder.layer.8.attention.self.key.bias": "model-00002-of-00010.safetensors",
292
+ "qformer.bert.encoder.layer.8.attention.self.key.weight": "model-00002-of-00010.safetensors",
293
+ "qformer.bert.encoder.layer.8.attention.self.query.bias": "model-00002-of-00010.safetensors",
294
+ "qformer.bert.encoder.layer.8.attention.self.query.weight": "model-00002-of-00010.safetensors",
295
+ "qformer.bert.encoder.layer.8.attention.self.value.bias": "model-00002-of-00010.safetensors",
296
+ "qformer.bert.encoder.layer.8.attention.self.value.weight": "model-00002-of-00010.safetensors",
297
+ "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
298
+ "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
299
+ "qformer.bert.encoder.layer.8.crossattention.output.dense.bias": "model-00002-of-00010.safetensors",
300
+ "qformer.bert.encoder.layer.8.crossattention.output.dense.weight": "model-00002-of-00010.safetensors",
301
+ "qformer.bert.encoder.layer.8.crossattention.self.key.bias": "model-00002-of-00010.safetensors",
302
+ "qformer.bert.encoder.layer.8.crossattention.self.key.weight": "model-00002-of-00010.safetensors",
303
+ "qformer.bert.encoder.layer.8.crossattention.self.query.bias": "model-00002-of-00010.safetensors",
304
+ "qformer.bert.encoder.layer.8.crossattention.self.query.weight": "model-00002-of-00010.safetensors",
305
+ "qformer.bert.encoder.layer.8.crossattention.self.value.bias": "model-00002-of-00010.safetensors",
306
+ "qformer.bert.encoder.layer.8.crossattention.self.value.weight": "model-00002-of-00010.safetensors",
307
+ "qformer.bert.encoder.layer.8.intermediate.dense.bias": "model-00002-of-00010.safetensors",
308
+ "qformer.bert.encoder.layer.8.intermediate.dense.weight": "model-00002-of-00010.safetensors",
309
+ "qformer.bert.encoder.layer.8.intermediate_query.dense.bias": "model-00002-of-00010.safetensors",
310
+ "qformer.bert.encoder.layer.8.intermediate_query.dense.weight": "model-00002-of-00010.safetensors",
311
+ "qformer.bert.encoder.layer.8.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
312
+ "qformer.bert.encoder.layer.8.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
313
+ "qformer.bert.encoder.layer.8.output.dense.bias": "model-00002-of-00010.safetensors",
314
+ "qformer.bert.encoder.layer.8.output.dense.weight": "model-00002-of-00010.safetensors",
315
+ "qformer.bert.encoder.layer.8.output_query.LayerNorm.bias": "model-00002-of-00010.safetensors",
316
+ "qformer.bert.encoder.layer.8.output_query.LayerNorm.weight": "model-00002-of-00010.safetensors",
317
+ "qformer.bert.encoder.layer.8.output_query.dense.bias": "model-00002-of-00010.safetensors",
318
+ "qformer.bert.encoder.layer.8.output_query.dense.weight": "model-00002-of-00010.safetensors",
319
+ "qformer.bert.encoder.layer.9.attention.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
320
+ "qformer.bert.encoder.layer.9.attention.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
321
+ "qformer.bert.encoder.layer.9.attention.output.dense.bias": "model-00002-of-00010.safetensors",
322
+ "qformer.bert.encoder.layer.9.attention.output.dense.weight": "model-00002-of-00010.safetensors",
323
+ "qformer.bert.encoder.layer.9.attention.self.key.bias": "model-00002-of-00010.safetensors",
324
+ "qformer.bert.encoder.layer.9.attention.self.key.weight": "model-00002-of-00010.safetensors",
325
+ "qformer.bert.encoder.layer.9.attention.self.query.bias": "model-00002-of-00010.safetensors",
326
+ "qformer.bert.encoder.layer.9.attention.self.query.weight": "model-00002-of-00010.safetensors",
327
+ "qformer.bert.encoder.layer.9.attention.self.value.bias": "model-00002-of-00010.safetensors",
328
+ "qformer.bert.encoder.layer.9.attention.self.value.weight": "model-00002-of-00010.safetensors",
329
+ "qformer.bert.encoder.layer.9.intermediate.dense.bias": "model-00002-of-00010.safetensors",
330
+ "qformer.bert.encoder.layer.9.intermediate.dense.weight": "model-00002-of-00010.safetensors",
331
+ "qformer.bert.encoder.layer.9.intermediate_query.dense.bias": "model-00002-of-00010.safetensors",
332
+ "qformer.bert.encoder.layer.9.intermediate_query.dense.weight": "model-00002-of-00010.safetensors",
333
+ "qformer.bert.encoder.layer.9.output.LayerNorm.bias": "model-00002-of-00010.safetensors",
334
+ "qformer.bert.encoder.layer.9.output.LayerNorm.weight": "model-00002-of-00010.safetensors",
335
+ "qformer.bert.encoder.layer.9.output.dense.bias": "model-00002-of-00010.safetensors",
336
+ "qformer.bert.encoder.layer.9.output.dense.weight": "model-00002-of-00010.safetensors",
337
+ "qformer.bert.encoder.layer.9.output_query.LayerNorm.bias": "model-00002-of-00010.safetensors",
338
+ "qformer.bert.encoder.layer.9.output_query.LayerNorm.weight": "model-00002-of-00010.safetensors",
339
+ "qformer.bert.encoder.layer.9.output_query.dense.bias": "model-00002-of-00010.safetensors",
340
+ "qformer.bert.encoder.layer.9.output_query.dense.weight": "model-00002-of-00010.safetensors",
341
+ "qformer.cls.predictions.bias": "model-00002-of-00010.safetensors",
342
+ "qformer.cls.predictions.decoder.bias": "model-00002-of-00010.safetensors",
343
+ "qformer.cls.predictions.decoder.weight": "model-00002-of-00010.safetensors",
344
+ "qformer.cls.predictions.transform.LayerNorm.bias": "model-00002-of-00010.safetensors",
345
+ "qformer.cls.predictions.transform.LayerNorm.weight": "model-00002-of-00010.safetensors",
346
+ "qformer.cls.predictions.transform.dense.bias": "model-00002-of-00010.safetensors",
347
+ "qformer.cls.predictions.transform.dense.weight": "model-00002-of-00010.safetensors",
348
+ "query_tokens": "model-00002-of-00010.safetensors",
349
+ "temporal_encoding.encoding.temp_embed": "model-00002-of-00010.safetensors",
350
+ "text_proj.bias": "model-00002-of-00010.safetensors",
351
+ "text_proj.weight": "model-00002-of-00010.safetensors",
352
+ "vision_proj.bias": "model-00002-of-00010.safetensors",
353
+ "vision_proj.weight": "model-00002-of-00010.safetensors",
354
+ "visual_encoder.blocks.0.attn.proj.bias": "model-00002-of-00010.safetensors",
355
+ "visual_encoder.blocks.0.attn.proj.weight": "model-00002-of-00010.safetensors",
356
+ "visual_encoder.blocks.0.attn.qkv.bias": "model-00002-of-00010.safetensors",
357
+ "visual_encoder.blocks.0.attn.qkv.weight": "model-00002-of-00010.safetensors",
358
+ "visual_encoder.blocks.0.mlp.fc1.bias": "model-00002-of-00010.safetensors",
359
+ "visual_encoder.blocks.0.mlp.fc1.weight": "model-00002-of-00010.safetensors",
360
+ "visual_encoder.blocks.0.mlp.fc2.bias": "model-00002-of-00010.safetensors",
361
+ "visual_encoder.blocks.0.mlp.fc2.weight": "model-00002-of-00010.safetensors",
362
+ "visual_encoder.blocks.0.norm1.bias": "model-00002-of-00010.safetensors",
363
+ "visual_encoder.blocks.0.norm1.weight": "model-00002-of-00010.safetensors",
364
+ "visual_encoder.blocks.0.norm2.bias": "model-00002-of-00010.safetensors",
365
+ "visual_encoder.blocks.0.norm2.weight": "model-00002-of-00010.safetensors",
366
+ "visual_encoder.blocks.1.attn.proj.bias": "model-00002-of-00010.safetensors",
367
+ "visual_encoder.blocks.1.attn.proj.weight": "model-00002-of-00010.safetensors",
368
+ "visual_encoder.blocks.1.attn.qkv.bias": "model-00002-of-00010.safetensors",
369
+ "visual_encoder.blocks.1.attn.qkv.weight": "model-00002-of-00010.safetensors",
370
+ "visual_encoder.blocks.1.mlp.fc1.bias": "model-00002-of-00010.safetensors",
371
+ "visual_encoder.blocks.1.mlp.fc1.weight": "model-00002-of-00010.safetensors",
372
+ "visual_encoder.blocks.1.mlp.fc2.bias": "model-00002-of-00010.safetensors",
373
+ "visual_encoder.blocks.1.mlp.fc2.weight": "model-00003-of-00010.safetensors",
374
+ "visual_encoder.blocks.1.norm1.bias": "model-00003-of-00010.safetensors",
375
+ "visual_encoder.blocks.1.norm1.weight": "model-00003-of-00010.safetensors",
376
+ "visual_encoder.blocks.1.norm2.bias": "model-00003-of-00010.safetensors",
377
+ "visual_encoder.blocks.1.norm2.weight": "model-00003-of-00010.safetensors",
378
+ "visual_encoder.blocks.10.attn.proj.bias": "model-00003-of-00010.safetensors",
379
+ "visual_encoder.blocks.10.attn.proj.weight": "model-00003-of-00010.safetensors",
380
+ "visual_encoder.blocks.10.attn.qkv.bias": "model-00003-of-00010.safetensors",
381
+ "visual_encoder.blocks.10.attn.qkv.weight": "model-00003-of-00010.safetensors",
382
+ "visual_encoder.blocks.10.mlp.fc1.bias": "model-00003-of-00010.safetensors",
383
+ "visual_encoder.blocks.10.mlp.fc1.weight": "model-00003-of-00010.safetensors",
384
+ "visual_encoder.blocks.10.mlp.fc2.bias": "model-00003-of-00010.safetensors",
385
+ "visual_encoder.blocks.10.mlp.fc2.weight": "model-00003-of-00010.safetensors",
386
+ "visual_encoder.blocks.10.norm1.bias": "model-00003-of-00010.safetensors",
387
+ "visual_encoder.blocks.10.norm1.weight": "model-00003-of-00010.safetensors",
388
+ "visual_encoder.blocks.10.norm2.bias": "model-00003-of-00010.safetensors",
389
+ "visual_encoder.blocks.10.norm2.weight": "model-00003-of-00010.safetensors",
390
+ "visual_encoder.blocks.11.attn.proj.bias": "model-00003-of-00010.safetensors",
391
+ "visual_encoder.blocks.11.attn.proj.weight": "model-00003-of-00010.safetensors",
392
+ "visual_encoder.blocks.11.attn.qkv.bias": "model-00003-of-00010.safetensors",
393
+ "visual_encoder.blocks.11.attn.qkv.weight": "model-00003-of-00010.safetensors",
394
+ "visual_encoder.blocks.11.mlp.fc1.bias": "model-00003-of-00010.safetensors",
395
+ "visual_encoder.blocks.11.mlp.fc1.weight": "model-00003-of-00010.safetensors",
396
+ "visual_encoder.blocks.11.mlp.fc2.bias": "model-00003-of-00010.safetensors",
397
+ "visual_encoder.blocks.11.mlp.fc2.weight": "model-00003-of-00010.safetensors",
398
+ "visual_encoder.blocks.11.norm1.bias": "model-00003-of-00010.safetensors",
399
+ "visual_encoder.blocks.11.norm1.weight": "model-00003-of-00010.safetensors",
400
+ "visual_encoder.blocks.11.norm2.bias": "model-00003-of-00010.safetensors",
401
+ "visual_encoder.blocks.11.norm2.weight": "model-00003-of-00010.safetensors",
402
+ "visual_encoder.blocks.12.attn.proj.bias": "model-00003-of-00010.safetensors",
403
+ "visual_encoder.blocks.12.attn.proj.weight": "model-00003-of-00010.safetensors",
404
+ "visual_encoder.blocks.12.attn.qkv.bias": "model-00003-of-00010.safetensors",
405
+ "visual_encoder.blocks.12.attn.qkv.weight": "model-00003-of-00010.safetensors",
406
+ "visual_encoder.blocks.12.mlp.fc1.bias": "model-00003-of-00010.safetensors",
407
+ "visual_encoder.blocks.12.mlp.fc1.weight": "model-00003-of-00010.safetensors",
408
+ "visual_encoder.blocks.12.mlp.fc2.bias": "model-00003-of-00010.safetensors",
409
+ "visual_encoder.blocks.12.mlp.fc2.weight": "model-00003-of-00010.safetensors",
410
+ "visual_encoder.blocks.12.norm1.bias": "model-00003-of-00010.safetensors",
411
+ "visual_encoder.blocks.12.norm1.weight": "model-00003-of-00010.safetensors",
412
+ "visual_encoder.blocks.12.norm2.bias": "model-00003-of-00010.safetensors",
413
+ "visual_encoder.blocks.12.norm2.weight": "model-00003-of-00010.safetensors",
414
+ "visual_encoder.blocks.13.attn.proj.bias": "model-00003-of-00010.safetensors",
415
+ "visual_encoder.blocks.13.attn.proj.weight": "model-00003-of-00010.safetensors",
416
+ "visual_encoder.blocks.13.attn.qkv.bias": "model-00003-of-00010.safetensors",
417
+ "visual_encoder.blocks.13.attn.qkv.weight": "model-00003-of-00010.safetensors",
418
+ "visual_encoder.blocks.13.mlp.fc1.bias": "model-00003-of-00010.safetensors",
419
+ "visual_encoder.blocks.13.mlp.fc1.weight": "model-00003-of-00010.safetensors",
420
+ "visual_encoder.blocks.13.mlp.fc2.bias": "model-00003-of-00010.safetensors",
421
+ "visual_encoder.blocks.13.mlp.fc2.weight": "model-00003-of-00010.safetensors",
422
+ "visual_encoder.blocks.13.norm1.bias": "model-00003-of-00010.safetensors",
423
+ "visual_encoder.blocks.13.norm1.weight": "model-00003-of-00010.safetensors",
424
+ "visual_encoder.blocks.13.norm2.bias": "model-00003-of-00010.safetensors",
425
+ "visual_encoder.blocks.13.norm2.weight": "model-00003-of-00010.safetensors",
426
+ "visual_encoder.blocks.14.attn.proj.bias": "model-00003-of-00010.safetensors",
427
+ "visual_encoder.blocks.14.attn.proj.weight": "model-00003-of-00010.safetensors",
428
+ "visual_encoder.blocks.14.attn.qkv.bias": "model-00003-of-00010.safetensors",
429
+ "visual_encoder.blocks.14.attn.qkv.weight": "model-00003-of-00010.safetensors",
430
+ "visual_encoder.blocks.14.mlp.fc1.bias": "model-00003-of-00010.safetensors",
431
+ "visual_encoder.blocks.14.mlp.fc1.weight": "model-00003-of-00010.safetensors",
432
+ "visual_encoder.blocks.14.mlp.fc2.bias": "model-00003-of-00010.safetensors",
433
+ "visual_encoder.blocks.14.mlp.fc2.weight": "model-00004-of-00010.safetensors",
434
+ "visual_encoder.blocks.14.norm1.bias": "model-00004-of-00010.safetensors",
435
+ "visual_encoder.blocks.14.norm1.weight": "model-00004-of-00010.safetensors",
436
+ "visual_encoder.blocks.14.norm2.bias": "model-00004-of-00010.safetensors",
437
+ "visual_encoder.blocks.14.norm2.weight": "model-00004-of-00010.safetensors",
438
+ "visual_encoder.blocks.15.attn.proj.bias": "model-00004-of-00010.safetensors",
439
+ "visual_encoder.blocks.15.attn.proj.weight": "model-00004-of-00010.safetensors",
440
+ "visual_encoder.blocks.15.attn.qkv.bias": "model-00004-of-00010.safetensors",
441
+ "visual_encoder.blocks.15.attn.qkv.weight": "model-00004-of-00010.safetensors",
442
+ "visual_encoder.blocks.15.mlp.fc1.bias": "model-00004-of-00010.safetensors",
443
+ "visual_encoder.blocks.15.mlp.fc1.weight": "model-00004-of-00010.safetensors",
444
+ "visual_encoder.blocks.15.mlp.fc2.bias": "model-00004-of-00010.safetensors",
445
+ "visual_encoder.blocks.15.mlp.fc2.weight": "model-00004-of-00010.safetensors",
446
+ "visual_encoder.blocks.15.norm1.bias": "model-00004-of-00010.safetensors",
447
+ "visual_encoder.blocks.15.norm1.weight": "model-00004-of-00010.safetensors",
448
+ "visual_encoder.blocks.15.norm2.bias": "model-00004-of-00010.safetensors",
449
+ "visual_encoder.blocks.15.norm2.weight": "model-00004-of-00010.safetensors",
450
+ "visual_encoder.blocks.16.attn.proj.bias": "model-00004-of-00010.safetensors",
451
+ "visual_encoder.blocks.16.attn.proj.weight": "model-00004-of-00010.safetensors",
452
+ "visual_encoder.blocks.16.attn.qkv.bias": "model-00004-of-00010.safetensors",
453
+ "visual_encoder.blocks.16.attn.qkv.weight": "model-00004-of-00010.safetensors",
454
+ "visual_encoder.blocks.16.mlp.fc1.bias": "model-00004-of-00010.safetensors",
455
+ "visual_encoder.blocks.16.mlp.fc1.weight": "model-00004-of-00010.safetensors",
456
+ "visual_encoder.blocks.16.mlp.fc2.bias": "model-00004-of-00010.safetensors",
457
+ "visual_encoder.blocks.16.mlp.fc2.weight": "model-00004-of-00010.safetensors",
458
+ "visual_encoder.blocks.16.norm1.bias": "model-00004-of-00010.safetensors",
459
+ "visual_encoder.blocks.16.norm1.weight": "model-00004-of-00010.safetensors",
460
+ "visual_encoder.blocks.16.norm2.bias": "model-00004-of-00010.safetensors",
461
+ "visual_encoder.blocks.16.norm2.weight": "model-00004-of-00010.safetensors",
462
+ "visual_encoder.blocks.17.attn.proj.bias": "model-00004-of-00010.safetensors",
463
+ "visual_encoder.blocks.17.attn.proj.weight": "model-00004-of-00010.safetensors",
464
+ "visual_encoder.blocks.17.attn.qkv.bias": "model-00004-of-00010.safetensors",
465
+ "visual_encoder.blocks.17.attn.qkv.weight": "model-00004-of-00010.safetensors",
466
+ "visual_encoder.blocks.17.mlp.fc1.bias": "model-00004-of-00010.safetensors",
467
+ "visual_encoder.blocks.17.mlp.fc1.weight": "model-00004-of-00010.safetensors",
468
+ "visual_encoder.blocks.17.mlp.fc2.bias": "model-00004-of-00010.safetensors",
469
+ "visual_encoder.blocks.17.mlp.fc2.weight": "model-00004-of-00010.safetensors",
470
+ "visual_encoder.blocks.17.norm1.bias": "model-00004-of-00010.safetensors",
471
+ "visual_encoder.blocks.17.norm1.weight": "model-00004-of-00010.safetensors",
472
+ "visual_encoder.blocks.17.norm2.bias": "model-00004-of-00010.safetensors",
473
+ "visual_encoder.blocks.17.norm2.weight": "model-00004-of-00010.safetensors",
474
+ "visual_encoder.blocks.18.attn.proj.bias": "model-00004-of-00010.safetensors",
475
+ "visual_encoder.blocks.18.attn.proj.weight": "model-00004-of-00010.safetensors",
476
+ "visual_encoder.blocks.18.attn.qkv.bias": "model-00004-of-00010.safetensors",
477
+ "visual_encoder.blocks.18.attn.qkv.weight": "model-00004-of-00010.safetensors",
478
+ "visual_encoder.blocks.18.mlp.fc1.bias": "model-00004-of-00010.safetensors",
479
+ "visual_encoder.blocks.18.mlp.fc1.weight": "model-00004-of-00010.safetensors",
480
+ "visual_encoder.blocks.18.mlp.fc2.bias": "model-00004-of-00010.safetensors",
481
+ "visual_encoder.blocks.18.mlp.fc2.weight": "model-00004-of-00010.safetensors",
482
+ "visual_encoder.blocks.18.norm1.bias": "model-00004-of-00010.safetensors",
483
+ "visual_encoder.blocks.18.norm1.weight": "model-00004-of-00010.safetensors",
484
+ "visual_encoder.blocks.18.norm2.bias": "model-00004-of-00010.safetensors",
485
+ "visual_encoder.blocks.18.norm2.weight": "model-00004-of-00010.safetensors",
486
+ "visual_encoder.blocks.19.attn.proj.bias": "model-00004-of-00010.safetensors",
487
+ "visual_encoder.blocks.19.attn.proj.weight": "model-00004-of-00010.safetensors",
488
+ "visual_encoder.blocks.19.attn.qkv.bias": "model-00004-of-00010.safetensors",
489
+ "visual_encoder.blocks.19.attn.qkv.weight": "model-00004-of-00010.safetensors",
490
+ "visual_encoder.blocks.19.mlp.fc1.bias": "model-00004-of-00010.safetensors",
491
+ "visual_encoder.blocks.19.mlp.fc1.weight": "model-00004-of-00010.safetensors",
492
+ "visual_encoder.blocks.19.mlp.fc2.bias": "model-00004-of-00010.safetensors",
493
+ "visual_encoder.blocks.19.mlp.fc2.weight": "model-00005-of-00010.safetensors",
494
+ "visual_encoder.blocks.19.norm1.bias": "model-00005-of-00010.safetensors",
495
+ "visual_encoder.blocks.19.norm1.weight": "model-00005-of-00010.safetensors",
496
+ "visual_encoder.blocks.19.norm2.bias": "model-00005-of-00010.safetensors",
497
+ "visual_encoder.blocks.19.norm2.weight": "model-00005-of-00010.safetensors",
498
+ "visual_encoder.blocks.2.attn.proj.bias": "model-00005-of-00010.safetensors",
499
+ "visual_encoder.blocks.2.attn.proj.weight": "model-00005-of-00010.safetensors",
500
+ "visual_encoder.blocks.2.attn.qkv.bias": "model-00005-of-00010.safetensors",
501
+ "visual_encoder.blocks.2.attn.qkv.weight": "model-00005-of-00010.safetensors",
502
+ "visual_encoder.blocks.2.mlp.fc1.bias": "model-00005-of-00010.safetensors",
503
+ "visual_encoder.blocks.2.mlp.fc1.weight": "model-00005-of-00010.safetensors",
504
+ "visual_encoder.blocks.2.mlp.fc2.bias": "model-00005-of-00010.safetensors",
505
+ "visual_encoder.blocks.2.mlp.fc2.weight": "model-00005-of-00010.safetensors",
506
+ "visual_encoder.blocks.2.norm1.bias": "model-00005-of-00010.safetensors",
507
+ "visual_encoder.blocks.2.norm1.weight": "model-00005-of-00010.safetensors",
508
+ "visual_encoder.blocks.2.norm2.bias": "model-00005-of-00010.safetensors",
509
+ "visual_encoder.blocks.2.norm2.weight": "model-00005-of-00010.safetensors",
510
+ "visual_encoder.blocks.20.attn.proj.bias": "model-00005-of-00010.safetensors",
511
+ "visual_encoder.blocks.20.attn.proj.weight": "model-00005-of-00010.safetensors",
512
+ "visual_encoder.blocks.20.attn.qkv.bias": "model-00005-of-00010.safetensors",
513
+ "visual_encoder.blocks.20.attn.qkv.weight": "model-00005-of-00010.safetensors",
514
+ "visual_encoder.blocks.20.mlp.fc1.bias": "model-00005-of-00010.safetensors",
515
+ "visual_encoder.blocks.20.mlp.fc1.weight": "model-00005-of-00010.safetensors",
516
+ "visual_encoder.blocks.20.mlp.fc2.bias": "model-00005-of-00010.safetensors",
517
+ "visual_encoder.blocks.20.mlp.fc2.weight": "model-00005-of-00010.safetensors",
518
+ "visual_encoder.blocks.20.norm1.bias": "model-00005-of-00010.safetensors",
519
+ "visual_encoder.blocks.20.norm1.weight": "model-00005-of-00010.safetensors",
520
+ "visual_encoder.blocks.20.norm2.bias": "model-00005-of-00010.safetensors",
521
+ "visual_encoder.blocks.20.norm2.weight": "model-00005-of-00010.safetensors",
522
+ "visual_encoder.blocks.21.attn.proj.bias": "model-00005-of-00010.safetensors",
523
+ "visual_encoder.blocks.21.attn.proj.weight": "model-00005-of-00010.safetensors",
524
+ "visual_encoder.blocks.21.attn.qkv.bias": "model-00005-of-00010.safetensors",
525
+ "visual_encoder.blocks.21.attn.qkv.weight": "model-00005-of-00010.safetensors",
526
+ "visual_encoder.blocks.21.mlp.fc1.bias": "model-00005-of-00010.safetensors",
527
+ "visual_encoder.blocks.21.mlp.fc1.weight": "model-00005-of-00010.safetensors",
528
+ "visual_encoder.blocks.21.mlp.fc2.bias": "model-00005-of-00010.safetensors",
529
+ "visual_encoder.blocks.21.mlp.fc2.weight": "model-00005-of-00010.safetensors",
530
+ "visual_encoder.blocks.21.norm1.bias": "model-00005-of-00010.safetensors",
531
+ "visual_encoder.blocks.21.norm1.weight": "model-00005-of-00010.safetensors",
532
+ "visual_encoder.blocks.21.norm2.bias": "model-00005-of-00010.safetensors",
533
+ "visual_encoder.blocks.21.norm2.weight": "model-00005-of-00010.safetensors",
534
+ "visual_encoder.blocks.22.attn.proj.bias": "model-00005-of-00010.safetensors",
535
+ "visual_encoder.blocks.22.attn.proj.weight": "model-00005-of-00010.safetensors",
536
+ "visual_encoder.blocks.22.attn.qkv.bias": "model-00005-of-00010.safetensors",
537
+ "visual_encoder.blocks.22.attn.qkv.weight": "model-00005-of-00010.safetensors",
538
+ "visual_encoder.blocks.22.mlp.fc1.bias": "model-00005-of-00010.safetensors",
539
+ "visual_encoder.blocks.22.mlp.fc1.weight": "model-00005-of-00010.safetensors",
540
+ "visual_encoder.blocks.22.mlp.fc2.bias": "model-00005-of-00010.safetensors",
541
+ "visual_encoder.blocks.22.mlp.fc2.weight": "model-00005-of-00010.safetensors",
542
+ "visual_encoder.blocks.22.norm1.bias": "model-00005-of-00010.safetensors",
543
+ "visual_encoder.blocks.22.norm1.weight": "model-00005-of-00010.safetensors",
544
+ "visual_encoder.blocks.22.norm2.bias": "model-00005-of-00010.safetensors",
545
+ "visual_encoder.blocks.22.norm2.weight": "model-00005-of-00010.safetensors",
546
+ "visual_encoder.blocks.23.attn.proj.bias": "model-00005-of-00010.safetensors",
547
+ "visual_encoder.blocks.23.attn.proj.weight": "model-00005-of-00010.safetensors",
548
+ "visual_encoder.blocks.23.attn.qkv.bias": "model-00005-of-00010.safetensors",
549
+ "visual_encoder.blocks.23.attn.qkv.weight": "model-00005-of-00010.safetensors",
550
+ "visual_encoder.blocks.23.mlp.fc1.bias": "model-00005-of-00010.safetensors",
551
+ "visual_encoder.blocks.23.mlp.fc1.weight": "model-00005-of-00010.safetensors",
552
+ "visual_encoder.blocks.23.mlp.fc2.bias": "model-00005-of-00010.safetensors",
553
+ "visual_encoder.blocks.23.mlp.fc2.weight": "model-00006-of-00010.safetensors",
554
+ "visual_encoder.blocks.23.norm1.bias": "model-00006-of-00010.safetensors",
555
+ "visual_encoder.blocks.23.norm1.weight": "model-00006-of-00010.safetensors",
556
+ "visual_encoder.blocks.23.norm2.bias": "model-00006-of-00010.safetensors",
557
+ "visual_encoder.blocks.23.norm2.weight": "model-00006-of-00010.safetensors",
558
+ "visual_encoder.blocks.24.attn.proj.bias": "model-00006-of-00010.safetensors",
559
+ "visual_encoder.blocks.24.attn.proj.weight": "model-00006-of-00010.safetensors",
560
+ "visual_encoder.blocks.24.attn.qkv.bias": "model-00006-of-00010.safetensors",
561
+ "visual_encoder.blocks.24.attn.qkv.weight": "model-00006-of-00010.safetensors",
562
+ "visual_encoder.blocks.24.mlp.fc1.bias": "model-00006-of-00010.safetensors",
563
+ "visual_encoder.blocks.24.mlp.fc1.weight": "model-00006-of-00010.safetensors",
564
+ "visual_encoder.blocks.24.mlp.fc2.bias": "model-00006-of-00010.safetensors",
565
+ "visual_encoder.blocks.24.mlp.fc2.weight": "model-00006-of-00010.safetensors",
566
+ "visual_encoder.blocks.24.norm1.bias": "model-00006-of-00010.safetensors",
567
+ "visual_encoder.blocks.24.norm1.weight": "model-00006-of-00010.safetensors",
568
+ "visual_encoder.blocks.24.norm2.bias": "model-00006-of-00010.safetensors",
569
+ "visual_encoder.blocks.24.norm2.weight": "model-00006-of-00010.safetensors",
570
+ "visual_encoder.blocks.25.attn.proj.bias": "model-00006-of-00010.safetensors",
571
+ "visual_encoder.blocks.25.attn.proj.weight": "model-00006-of-00010.safetensors",
572
+ "visual_encoder.blocks.25.attn.qkv.bias": "model-00006-of-00010.safetensors",
573
+ "visual_encoder.blocks.25.attn.qkv.weight": "model-00006-of-00010.safetensors",
574
+ "visual_encoder.blocks.25.mlp.fc1.bias": "model-00006-of-00010.safetensors",
575
+ "visual_encoder.blocks.25.mlp.fc1.weight": "model-00006-of-00010.safetensors",
576
+ "visual_encoder.blocks.25.mlp.fc2.bias": "model-00006-of-00010.safetensors",
577
+ "visual_encoder.blocks.25.mlp.fc2.weight": "model-00006-of-00010.safetensors",
578
+ "visual_encoder.blocks.25.norm1.bias": "model-00006-of-00010.safetensors",
579
+ "visual_encoder.blocks.25.norm1.weight": "model-00006-of-00010.safetensors",
580
+ "visual_encoder.blocks.25.norm2.bias": "model-00006-of-00010.safetensors",
581
+ "visual_encoder.blocks.25.norm2.weight": "model-00006-of-00010.safetensors",
582
+ "visual_encoder.blocks.26.attn.proj.bias": "model-00006-of-00010.safetensors",
583
+ "visual_encoder.blocks.26.attn.proj.weight": "model-00006-of-00010.safetensors",
584
+ "visual_encoder.blocks.26.attn.qkv.bias": "model-00006-of-00010.safetensors",
585
+ "visual_encoder.blocks.26.attn.qkv.weight": "model-00006-of-00010.safetensors",
586
+ "visual_encoder.blocks.26.mlp.fc1.bias": "model-00006-of-00010.safetensors",
587
+ "visual_encoder.blocks.26.mlp.fc1.weight": "model-00006-of-00010.safetensors",
588
+ "visual_encoder.blocks.26.mlp.fc2.bias": "model-00006-of-00010.safetensors",
589
+ "visual_encoder.blocks.26.mlp.fc2.weight": "model-00006-of-00010.safetensors",
590
+ "visual_encoder.blocks.26.norm1.bias": "model-00006-of-00010.safetensors",
591
+ "visual_encoder.blocks.26.norm1.weight": "model-00006-of-00010.safetensors",
592
+ "visual_encoder.blocks.26.norm2.bias": "model-00006-of-00010.safetensors",
593
+ "visual_encoder.blocks.26.norm2.weight": "model-00006-of-00010.safetensors",
594
+ "visual_encoder.blocks.27.attn.proj.bias": "model-00006-of-00010.safetensors",
595
+ "visual_encoder.blocks.27.attn.proj.weight": "model-00006-of-00010.safetensors",
596
+ "visual_encoder.blocks.27.attn.qkv.bias": "model-00006-of-00010.safetensors",
597
+ "visual_encoder.blocks.27.attn.qkv.weight": "model-00006-of-00010.safetensors",
598
+ "visual_encoder.blocks.27.mlp.fc1.bias": "model-00006-of-00010.safetensors",
599
+ "visual_encoder.blocks.27.mlp.fc1.weight": "model-00006-of-00010.safetensors",
600
+ "visual_encoder.blocks.27.mlp.fc2.bias": "model-00006-of-00010.safetensors",
601
+ "visual_encoder.blocks.27.mlp.fc2.weight": "model-00006-of-00010.safetensors",
602
+ "visual_encoder.blocks.27.norm1.bias": "model-00006-of-00010.safetensors",
603
+ "visual_encoder.blocks.27.norm1.weight": "model-00006-of-00010.safetensors",
604
+ "visual_encoder.blocks.27.norm2.bias": "model-00006-of-00010.safetensors",
605
+ "visual_encoder.blocks.27.norm2.weight": "model-00006-of-00010.safetensors",
606
+ "visual_encoder.blocks.28.attn.proj.bias": "model-00006-of-00010.safetensors",
607
+ "visual_encoder.blocks.28.attn.proj.weight": "model-00006-of-00010.safetensors",
608
+ "visual_encoder.blocks.28.attn.qkv.bias": "model-00006-of-00010.safetensors",
609
+ "visual_encoder.blocks.28.attn.qkv.weight": "model-00006-of-00010.safetensors",
610
+ "visual_encoder.blocks.28.mlp.fc1.bias": "model-00006-of-00010.safetensors",
611
+ "visual_encoder.blocks.28.mlp.fc1.weight": "model-00006-of-00010.safetensors",
612
+ "visual_encoder.blocks.28.mlp.fc2.bias": "model-00006-of-00010.safetensors",
613
+ "visual_encoder.blocks.28.mlp.fc2.weight": "model-00007-of-00010.safetensors",
614
+ "visual_encoder.blocks.28.norm1.bias": "model-00007-of-00010.safetensors",
615
+ "visual_encoder.blocks.28.norm1.weight": "model-00007-of-00010.safetensors",
616
+ "visual_encoder.blocks.28.norm2.bias": "model-00007-of-00010.safetensors",
617
+ "visual_encoder.blocks.28.norm2.weight": "model-00007-of-00010.safetensors",
618
+ "visual_encoder.blocks.29.attn.proj.bias": "model-00007-of-00010.safetensors",
619
+ "visual_encoder.blocks.29.attn.proj.weight": "model-00007-of-00010.safetensors",
620
+ "visual_encoder.blocks.29.attn.qkv.bias": "model-00007-of-00010.safetensors",
621
+ "visual_encoder.blocks.29.attn.qkv.weight": "model-00007-of-00010.safetensors",
622
+ "visual_encoder.blocks.29.mlp.fc1.bias": "model-00007-of-00010.safetensors",
623
+ "visual_encoder.blocks.29.mlp.fc1.weight": "model-00007-of-00010.safetensors",
624
+ "visual_encoder.blocks.29.mlp.fc2.bias": "model-00007-of-00010.safetensors",
625
+ "visual_encoder.blocks.29.mlp.fc2.weight": "model-00007-of-00010.safetensors",
626
+ "visual_encoder.blocks.29.norm1.bias": "model-00007-of-00010.safetensors",
627
+ "visual_encoder.blocks.29.norm1.weight": "model-00007-of-00010.safetensors",
628
+ "visual_encoder.blocks.29.norm2.bias": "model-00007-of-00010.safetensors",
629
+ "visual_encoder.blocks.29.norm2.weight": "model-00007-of-00010.safetensors",
630
+ "visual_encoder.blocks.3.attn.proj.bias": "model-00007-of-00010.safetensors",
631
+ "visual_encoder.blocks.3.attn.proj.weight": "model-00007-of-00010.safetensors",
632
+ "visual_encoder.blocks.3.attn.qkv.bias": "model-00007-of-00010.safetensors",
633
+ "visual_encoder.blocks.3.attn.qkv.weight": "model-00007-of-00010.safetensors",
634
+ "visual_encoder.blocks.3.mlp.fc1.bias": "model-00007-of-00010.safetensors",
635
+ "visual_encoder.blocks.3.mlp.fc1.weight": "model-00007-of-00010.safetensors",
636
+ "visual_encoder.blocks.3.mlp.fc2.bias": "model-00007-of-00010.safetensors",
637
+ "visual_encoder.blocks.3.mlp.fc2.weight": "model-00007-of-00010.safetensors",
638
+ "visual_encoder.blocks.3.norm1.bias": "model-00007-of-00010.safetensors",
639
+ "visual_encoder.blocks.3.norm1.weight": "model-00007-of-00010.safetensors",
640
+ "visual_encoder.blocks.3.norm2.bias": "model-00007-of-00010.safetensors",
641
+ "visual_encoder.blocks.3.norm2.weight": "model-00007-of-00010.safetensors",
642
+ "visual_encoder.blocks.30.attn.proj.bias": "model-00007-of-00010.safetensors",
643
+ "visual_encoder.blocks.30.attn.proj.weight": "model-00007-of-00010.safetensors",
644
+ "visual_encoder.blocks.30.attn.qkv.bias": "model-00007-of-00010.safetensors",
645
+ "visual_encoder.blocks.30.attn.qkv.weight": "model-00007-of-00010.safetensors",
646
+ "visual_encoder.blocks.30.mlp.fc1.bias": "model-00007-of-00010.safetensors",
647
+ "visual_encoder.blocks.30.mlp.fc1.weight": "model-00007-of-00010.safetensors",
648
+ "visual_encoder.blocks.30.mlp.fc2.bias": "model-00007-of-00010.safetensors",
649
+ "visual_encoder.blocks.30.mlp.fc2.weight": "model-00007-of-00010.safetensors",
650
+ "visual_encoder.blocks.30.norm1.bias": "model-00007-of-00010.safetensors",
651
+ "visual_encoder.blocks.30.norm1.weight": "model-00007-of-00010.safetensors",
652
+ "visual_encoder.blocks.30.norm2.bias": "model-00007-of-00010.safetensors",
653
+ "visual_encoder.blocks.30.norm2.weight": "model-00007-of-00010.safetensors",
654
+ "visual_encoder.blocks.31.attn.proj.bias": "model-00007-of-00010.safetensors",
655
+ "visual_encoder.blocks.31.attn.proj.weight": "model-00007-of-00010.safetensors",
656
+ "visual_encoder.blocks.31.attn.qkv.bias": "model-00007-of-00010.safetensors",
657
+ "visual_encoder.blocks.31.attn.qkv.weight": "model-00007-of-00010.safetensors",
658
+ "visual_encoder.blocks.31.mlp.fc1.bias": "model-00007-of-00010.safetensors",
659
+ "visual_encoder.blocks.31.mlp.fc1.weight": "model-00007-of-00010.safetensors",
660
+ "visual_encoder.blocks.31.mlp.fc2.bias": "model-00007-of-00010.safetensors",
661
+ "visual_encoder.blocks.31.mlp.fc2.weight": "model-00007-of-00010.safetensors",
662
+ "visual_encoder.blocks.31.norm1.bias": "model-00007-of-00010.safetensors",
663
+ "visual_encoder.blocks.31.norm1.weight": "model-00007-of-00010.safetensors",
664
+ "visual_encoder.blocks.31.norm2.bias": "model-00007-of-00010.safetensors",
665
+ "visual_encoder.blocks.31.norm2.weight": "model-00007-of-00010.safetensors",
666
+ "visual_encoder.blocks.32.attn.proj.bias": "model-00007-of-00010.safetensors",
667
+ "visual_encoder.blocks.32.attn.proj.weight": "model-00007-of-00010.safetensors",
668
+ "visual_encoder.blocks.32.attn.qkv.bias": "model-00007-of-00010.safetensors",
669
+ "visual_encoder.blocks.32.attn.qkv.weight": "model-00007-of-00010.safetensors",
670
+ "visual_encoder.blocks.32.mlp.fc1.bias": "model-00007-of-00010.safetensors",
671
+ "visual_encoder.blocks.32.mlp.fc1.weight": "model-00007-of-00010.safetensors",
672
+ "visual_encoder.blocks.32.mlp.fc2.bias": "model-00007-of-00010.safetensors",
673
+ "visual_encoder.blocks.32.mlp.fc2.weight": "model-00008-of-00010.safetensors",
674
+ "visual_encoder.blocks.32.norm1.bias": "model-00008-of-00010.safetensors",
675
+ "visual_encoder.blocks.32.norm1.weight": "model-00008-of-00010.safetensors",
676
+ "visual_encoder.blocks.32.norm2.bias": "model-00008-of-00010.safetensors",
677
+ "visual_encoder.blocks.32.norm2.weight": "model-00008-of-00010.safetensors",
678
+ "visual_encoder.blocks.33.attn.proj.bias": "model-00008-of-00010.safetensors",
679
+ "visual_encoder.blocks.33.attn.proj.weight": "model-00008-of-00010.safetensors",
680
+ "visual_encoder.blocks.33.attn.qkv.bias": "model-00008-of-00010.safetensors",
681
+ "visual_encoder.blocks.33.attn.qkv.weight": "model-00008-of-00010.safetensors",
682
+ "visual_encoder.blocks.33.mlp.fc1.bias": "model-00008-of-00010.safetensors",
683
+ "visual_encoder.blocks.33.mlp.fc1.weight": "model-00008-of-00010.safetensors",
684
+ "visual_encoder.blocks.33.mlp.fc2.bias": "model-00008-of-00010.safetensors",
685
+ "visual_encoder.blocks.33.mlp.fc2.weight": "model-00008-of-00010.safetensors",
686
+ "visual_encoder.blocks.33.norm1.bias": "model-00008-of-00010.safetensors",
687
+ "visual_encoder.blocks.33.norm1.weight": "model-00008-of-00010.safetensors",
688
+ "visual_encoder.blocks.33.norm2.bias": "model-00008-of-00010.safetensors",
689
+ "visual_encoder.blocks.33.norm2.weight": "model-00008-of-00010.safetensors",
690
+ "visual_encoder.blocks.34.attn.proj.bias": "model-00008-of-00010.safetensors",
691
+ "visual_encoder.blocks.34.attn.proj.weight": "model-00008-of-00010.safetensors",
692
+ "visual_encoder.blocks.34.attn.qkv.bias": "model-00008-of-00010.safetensors",
693
+ "visual_encoder.blocks.34.attn.qkv.weight": "model-00008-of-00010.safetensors",
694
+ "visual_encoder.blocks.34.mlp.fc1.bias": "model-00008-of-00010.safetensors",
695
+ "visual_encoder.blocks.34.mlp.fc1.weight": "model-00008-of-00010.safetensors",
696
+ "visual_encoder.blocks.34.mlp.fc2.bias": "model-00008-of-00010.safetensors",
697
+ "visual_encoder.blocks.34.mlp.fc2.weight": "model-00008-of-00010.safetensors",
698
+ "visual_encoder.blocks.34.norm1.bias": "model-00008-of-00010.safetensors",
699
+ "visual_encoder.blocks.34.norm1.weight": "model-00008-of-00010.safetensors",
700
+ "visual_encoder.blocks.34.norm2.bias": "model-00008-of-00010.safetensors",
701
+ "visual_encoder.blocks.34.norm2.weight": "model-00008-of-00010.safetensors",
702
+ "visual_encoder.blocks.35.attn.proj.bias": "model-00008-of-00010.safetensors",
703
+ "visual_encoder.blocks.35.attn.proj.weight": "model-00008-of-00010.safetensors",
704
+ "visual_encoder.blocks.35.attn.qkv.bias": "model-00008-of-00010.safetensors",
705
+ "visual_encoder.blocks.35.attn.qkv.weight": "model-00008-of-00010.safetensors",
706
+ "visual_encoder.blocks.35.mlp.fc1.bias": "model-00008-of-00010.safetensors",
707
+ "visual_encoder.blocks.35.mlp.fc1.weight": "model-00008-of-00010.safetensors",
708
+ "visual_encoder.blocks.35.mlp.fc2.bias": "model-00008-of-00010.safetensors",
709
+ "visual_encoder.blocks.35.mlp.fc2.weight": "model-00008-of-00010.safetensors",
710
+ "visual_encoder.blocks.35.norm1.bias": "model-00008-of-00010.safetensors",
711
+ "visual_encoder.blocks.35.norm1.weight": "model-00008-of-00010.safetensors",
712
+ "visual_encoder.blocks.35.norm2.bias": "model-00008-of-00010.safetensors",
713
+ "visual_encoder.blocks.35.norm2.weight": "model-00008-of-00010.safetensors",
714
+ "visual_encoder.blocks.36.attn.proj.bias": "model-00008-of-00010.safetensors",
715
+ "visual_encoder.blocks.36.attn.proj.weight": "model-00008-of-00010.safetensors",
716
+ "visual_encoder.blocks.36.attn.qkv.bias": "model-00008-of-00010.safetensors",
717
+ "visual_encoder.blocks.36.attn.qkv.weight": "model-00008-of-00010.safetensors",
718
+ "visual_encoder.blocks.36.mlp.fc1.bias": "model-00008-of-00010.safetensors",
719
+ "visual_encoder.blocks.36.mlp.fc1.weight": "model-00008-of-00010.safetensors",
720
+ "visual_encoder.blocks.36.mlp.fc2.bias": "model-00008-of-00010.safetensors",
721
+ "visual_encoder.blocks.36.mlp.fc2.weight": "model-00008-of-00010.safetensors",
722
+ "visual_encoder.blocks.36.norm1.bias": "model-00008-of-00010.safetensors",
723
+ "visual_encoder.blocks.36.norm1.weight": "model-00008-of-00010.safetensors",
724
+ "visual_encoder.blocks.36.norm2.bias": "model-00008-of-00010.safetensors",
725
+ "visual_encoder.blocks.36.norm2.weight": "model-00008-of-00010.safetensors",
726
+ "visual_encoder.blocks.37.attn.proj.bias": "model-00008-of-00010.safetensors",
727
+ "visual_encoder.blocks.37.attn.proj.weight": "model-00008-of-00010.safetensors",
728
+ "visual_encoder.blocks.37.attn.qkv.bias": "model-00008-of-00010.safetensors",
729
+ "visual_encoder.blocks.37.attn.qkv.weight": "model-00008-of-00010.safetensors",
730
+ "visual_encoder.blocks.37.mlp.fc1.bias": "model-00008-of-00010.safetensors",
731
+ "visual_encoder.blocks.37.mlp.fc1.weight": "model-00008-of-00010.safetensors",
732
+ "visual_encoder.blocks.37.mlp.fc2.bias": "model-00008-of-00010.safetensors",
733
+ "visual_encoder.blocks.37.mlp.fc2.weight": "model-00009-of-00010.safetensors",
734
+ "visual_encoder.blocks.37.norm1.bias": "model-00009-of-00010.safetensors",
735
+ "visual_encoder.blocks.37.norm1.weight": "model-00009-of-00010.safetensors",
736
+ "visual_encoder.blocks.37.norm2.bias": "model-00009-of-00010.safetensors",
737
+ "visual_encoder.blocks.37.norm2.weight": "model-00009-of-00010.safetensors",
738
+ "visual_encoder.blocks.38.attn.proj.bias": "model-00009-of-00010.safetensors",
739
+ "visual_encoder.blocks.38.attn.proj.weight": "model-00009-of-00010.safetensors",
740
+ "visual_encoder.blocks.38.attn.qkv.bias": "model-00009-of-00010.safetensors",
741
+ "visual_encoder.blocks.38.attn.qkv.weight": "model-00009-of-00010.safetensors",
742
+ "visual_encoder.blocks.38.mlp.fc1.bias": "model-00009-of-00010.safetensors",
743
+ "visual_encoder.blocks.38.mlp.fc1.weight": "model-00009-of-00010.safetensors",
744
+ "visual_encoder.blocks.38.mlp.fc2.bias": "model-00009-of-00010.safetensors",
745
+ "visual_encoder.blocks.38.mlp.fc2.weight": "model-00009-of-00010.safetensors",
746
+ "visual_encoder.blocks.38.norm1.bias": "model-00009-of-00010.safetensors",
747
+ "visual_encoder.blocks.38.norm1.weight": "model-00009-of-00010.safetensors",
748
+ "visual_encoder.blocks.38.norm2.bias": "model-00009-of-00010.safetensors",
749
+ "visual_encoder.blocks.38.norm2.weight": "model-00009-of-00010.safetensors",
750
+ "visual_encoder.blocks.4.attn.proj.bias": "model-00009-of-00010.safetensors",
751
+ "visual_encoder.blocks.4.attn.proj.weight": "model-00009-of-00010.safetensors",
752
+ "visual_encoder.blocks.4.attn.qkv.bias": "model-00009-of-00010.safetensors",
753
+ "visual_encoder.blocks.4.attn.qkv.weight": "model-00009-of-00010.safetensors",
754
+ "visual_encoder.blocks.4.mlp.fc1.bias": "model-00009-of-00010.safetensors",
755
+ "visual_encoder.blocks.4.mlp.fc1.weight": "model-00009-of-00010.safetensors",
756
+ "visual_encoder.blocks.4.mlp.fc2.bias": "model-00009-of-00010.safetensors",
757
+ "visual_encoder.blocks.4.mlp.fc2.weight": "model-00009-of-00010.safetensors",
758
+ "visual_encoder.blocks.4.norm1.bias": "model-00009-of-00010.safetensors",
759
+ "visual_encoder.blocks.4.norm1.weight": "model-00009-of-00010.safetensors",
760
+ "visual_encoder.blocks.4.norm2.bias": "model-00009-of-00010.safetensors",
761
+ "visual_encoder.blocks.4.norm2.weight": "model-00009-of-00010.safetensors",
762
+ "visual_encoder.blocks.5.attn.proj.bias": "model-00009-of-00010.safetensors",
763
+ "visual_encoder.blocks.5.attn.proj.weight": "model-00009-of-00010.safetensors",
764
+ "visual_encoder.blocks.5.attn.qkv.bias": "model-00009-of-00010.safetensors",
765
+ "visual_encoder.blocks.5.attn.qkv.weight": "model-00009-of-00010.safetensors",
766
+ "visual_encoder.blocks.5.mlp.fc1.bias": "model-00009-of-00010.safetensors",
767
+ "visual_encoder.blocks.5.mlp.fc1.weight": "model-00009-of-00010.safetensors",
768
+ "visual_encoder.blocks.5.mlp.fc2.bias": "model-00009-of-00010.safetensors",
769
+ "visual_encoder.blocks.5.mlp.fc2.weight": "model-00009-of-00010.safetensors",
770
+ "visual_encoder.blocks.5.norm1.bias": "model-00009-of-00010.safetensors",
771
+ "visual_encoder.blocks.5.norm1.weight": "model-00009-of-00010.safetensors",
772
+ "visual_encoder.blocks.5.norm2.bias": "model-00009-of-00010.safetensors",
773
+ "visual_encoder.blocks.5.norm2.weight": "model-00009-of-00010.safetensors",
774
+ "visual_encoder.blocks.6.attn.proj.bias": "model-00009-of-00010.safetensors",
775
+ "visual_encoder.blocks.6.attn.proj.weight": "model-00009-of-00010.safetensors",
776
+ "visual_encoder.blocks.6.attn.qkv.bias": "model-00009-of-00010.safetensors",
777
+ "visual_encoder.blocks.6.attn.qkv.weight": "model-00009-of-00010.safetensors",
778
+ "visual_encoder.blocks.6.mlp.fc1.bias": "model-00009-of-00010.safetensors",
779
+ "visual_encoder.blocks.6.mlp.fc1.weight": "model-00009-of-00010.safetensors",
780
+ "visual_encoder.blocks.6.mlp.fc2.bias": "model-00009-of-00010.safetensors",
781
+ "visual_encoder.blocks.6.mlp.fc2.weight": "model-00009-of-00010.safetensors",
782
+ "visual_encoder.blocks.6.norm1.bias": "model-00009-of-00010.safetensors",
783
+ "visual_encoder.blocks.6.norm1.weight": "model-00009-of-00010.safetensors",
784
+ "visual_encoder.blocks.6.norm2.bias": "model-00009-of-00010.safetensors",
785
+ "visual_encoder.blocks.6.norm2.weight": "model-00009-of-00010.safetensors",
786
+ "visual_encoder.blocks.7.attn.proj.bias": "model-00009-of-00010.safetensors",
787
+ "visual_encoder.blocks.7.attn.proj.weight": "model-00009-of-00010.safetensors",
788
+ "visual_encoder.blocks.7.attn.qkv.bias": "model-00009-of-00010.safetensors",
789
+ "visual_encoder.blocks.7.attn.qkv.weight": "model-00009-of-00010.safetensors",
790
+ "visual_encoder.blocks.7.mlp.fc1.bias": "model-00009-of-00010.safetensors",
791
+ "visual_encoder.blocks.7.mlp.fc1.weight": "model-00009-of-00010.safetensors",
792
+ "visual_encoder.blocks.7.mlp.fc2.bias": "model-00009-of-00010.safetensors",
793
+ "visual_encoder.blocks.7.mlp.fc2.weight": "model-00010-of-00010.safetensors",
794
+ "visual_encoder.blocks.7.norm1.bias": "model-00010-of-00010.safetensors",
795
+ "visual_encoder.blocks.7.norm1.weight": "model-00010-of-00010.safetensors",
796
+ "visual_encoder.blocks.7.norm2.bias": "model-00010-of-00010.safetensors",
797
+ "visual_encoder.blocks.7.norm2.weight": "model-00010-of-00010.safetensors",
798
+ "visual_encoder.blocks.8.attn.proj.bias": "model-00010-of-00010.safetensors",
799
+ "visual_encoder.blocks.8.attn.proj.weight": "model-00010-of-00010.safetensors",
800
+ "visual_encoder.blocks.8.attn.qkv.bias": "model-00010-of-00010.safetensors",
801
+ "visual_encoder.blocks.8.attn.qkv.weight": "model-00010-of-00010.safetensors",
802
+ "visual_encoder.blocks.8.mlp.fc1.bias": "model-00010-of-00010.safetensors",
803
+ "visual_encoder.blocks.8.mlp.fc1.weight": "model-00010-of-00010.safetensors",
804
+ "visual_encoder.blocks.8.mlp.fc2.bias": "model-00010-of-00010.safetensors",
805
+ "visual_encoder.blocks.8.mlp.fc2.weight": "model-00010-of-00010.safetensors",
806
+ "visual_encoder.blocks.8.norm1.bias": "model-00010-of-00010.safetensors",
807
+ "visual_encoder.blocks.8.norm1.weight": "model-00010-of-00010.safetensors",
808
+ "visual_encoder.blocks.8.norm2.bias": "model-00010-of-00010.safetensors",
809
+ "visual_encoder.blocks.8.norm2.weight": "model-00010-of-00010.safetensors",
810
+ "visual_encoder.blocks.9.attn.proj.bias": "model-00010-of-00010.safetensors",
811
+ "visual_encoder.blocks.9.attn.proj.weight": "model-00010-of-00010.safetensors",
812
+ "visual_encoder.blocks.9.attn.qkv.bias": "model-00010-of-00010.safetensors",
813
+ "visual_encoder.blocks.9.attn.qkv.weight": "model-00010-of-00010.safetensors",
814
+ "visual_encoder.blocks.9.mlp.fc1.bias": "model-00010-of-00010.safetensors",
815
+ "visual_encoder.blocks.9.mlp.fc1.weight": "model-00010-of-00010.safetensors",
816
+ "visual_encoder.blocks.9.mlp.fc2.bias": "model-00010-of-00010.safetensors",
817
+ "visual_encoder.blocks.9.mlp.fc2.weight": "model-00010-of-00010.safetensors",
818
+ "visual_encoder.blocks.9.norm1.bias": "model-00010-of-00010.safetensors",
819
+ "visual_encoder.blocks.9.norm1.weight": "model-00010-of-00010.safetensors",
820
+ "visual_encoder.blocks.9.norm2.bias": "model-00010-of-00010.safetensors",
821
+ "visual_encoder.blocks.9.norm2.weight": "model-00010-of-00010.safetensors",
822
+ "visual_encoder.cls_token": "model-00010-of-00010.safetensors",
823
+ "visual_encoder.patch_embed.proj.bias": "model-00010-of-00010.safetensors",
824
+ "visual_encoder.patch_embed.proj.weight": "model-00010-of-00010.safetensors",
825
+ "visual_encoder.pos_embed": "model-00010-of-00010.safetensors"
826
+ }
827
+ }
modeling_embed1.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Cosmos-Embed1 text+video embedder."""
17
+
18
+ import math
19
+ from copy import deepcopy
20
+
21
+ import torch
22
+ from einops import rearrange
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+ from transformers import AutoModel, PreTrainedModel
26
+
27
+ from .configuration_embed1 import CosmosEmbed1Config
28
+ from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput
29
+ from .modeling_qformer import BertLMHeadModel, load_qformer
30
+ from .modeling_utils import EncodingFactory, rank0_first
31
+ from .modeling_vit import EvaViTG
32
+
33
+
34
class CosmosEmbed1(PreTrainedModel):
    """Cosmos-Embed1 joint text+video embedder.

    Combines an EVA ViT-g visual encoder, a BERT-based Q-Former, and linear
    projection heads that map both modalities into a shared
    ``embed_dim``-dimensional space for contrastive text/video alignment.
    """

    config_class = CosmosEmbed1Config

    def __init__(self, config: CosmosEmbed1Config) -> None:
        """Cosmos-Embed1 video embedder constructor.

        Args:
            config (CosmosEmbed1Config): Model configuration.
        """
        super().__init__(config)

        # Cache config fields as attributes for convenient access below.
        self.embed_dim = config.embed_dim
        self.num_query_tokens = config.num_query_tokens
        self.num_video_frames = config.num_video_frames
        self.temporal_encoding_type = config.temporal_encoding_type
        self.resolution = config.resolution
        self.vocab_size = config.vocab_size
        self.transformer_engine = config.transformer_engine
        self.use_fp8 = config.use_fp8

        # visual encoder initialization
        # ImageNet channel statistics, shaped (1, 1, C, 1, 1) to broadcast over
        # (batch, time, channel, height, width) video tensors. Non-persistent:
        # they are constants and are not saved in checkpoints.
        self.register_buffer(
            "normalization_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "normalization_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.visual_encoder = EvaViTG(
            img_size=self.resolution,
            transformer_engine=self.transformer_engine,
            use_fp8=self.use_fp8,
        )
        self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)

        # qformer initialization
        self.qformer, self.query_tokens = self._init_qformer(
            num_query_tokens=self.num_query_tokens,
            encoder_width=self.visual_encoder.embed_dim,
            vocab_size=self.vocab_size,
        )
        # Tie weights: every parameter whose name contains "_query" is
        # initialized from its same-named counterpart without the "_query"
        # suffix, so the query branch starts from the shared BERT weights.
        state_dict = self.qformer.state_dict()
        for name, param in self.qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        # temporal encoding
        self.temporal_encoding = EncodingFactory(
            self.temporal_encoding_type,
            embed_dim=self.visual_encoder.embed_dim,
            max_len=self.num_video_frames,
        )

        # output projections
        self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)
        # initialize logit scale/bias like SigLIP (as per Table 4 in https://arxiv.org/pdf/2303.15343)
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.logit_bias = nn.Parameter(torch.tensor(-10.0))

    @property
    def hidden_dim(self) -> int:
        """Hidden dimension of the visual encoder backbone."""
        return self.visual_encoder.embed_dim

    @torch.jit.ignore
    def no_weight_decay(self) -> set:
        # Exclude the SigLIP-style logit scale/bias from weight decay.
        ret = {"logit_scale", "logit_bias"}
        return ret

    def forward(
        self,
        videos: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextVideoEmbedderOutput:
        """Forward function for `CosmosEmbed1`.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                batched videos with fixed number of RGB frames.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained by using [`AutoTokenizer`, `CosmosEmbed1Tokenizer`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values select in `[0, 1]`.
                - 1 for tokens that are **not masked**.
                - 0 for tokens that are **masked**.

        Returns:
            TextVideoEmbedderOutput: merged video- and text-branch outputs.
        """
        video_output = self.get_video_embeddings(videos)
        text_output = self.get_text_embeddings(input_ids, attention_mask)
        return TextVideoEmbedderOutput(**video_output, **text_output)

    def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
        """Encode videos into projected, pooled, and spatial embeddings.

        Args:
            videos: float tensor of shape `(batch_size, num_frames, 3, H, W)`;
                assumed to hold RGB values in [0, 1] (ImageNet-normalized
                here) — TODO confirm against the processor.
        """
        videos = (videos - self.normalization_mean) / self.normalization_std
        batch_size, num_frames, _, H, W = videos.shape
        # Flatten time into the batch so the ViT sees independent frames.
        frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")

        # process video frames through ViT
        visual_embs = self.visual_encoder(frame_batch)
        visual_embs = self.ln_vision(visual_embs)
        visual_embs = rearrange(
            visual_embs,
            "(b t) k d -> b t k d",
            b=batch_size,
            t=num_frames,
            k=visual_embs.size(1),
            d=visual_embs.size(2),
        )

        # add temporal encoding
        visual_embs = self.temporal_encoding(visual_embs)

        # Q-Former cross-attention: learned query tokens attend over all
        # (frame x patch) tokens at once.
        encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
        encoder_attention_mask = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long).to(videos.device)
        query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
        visual_query_output = self.qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        # Mean-pool the query outputs into a single video-level token,
        # then project and L2-normalize for contrastive similarity.
        visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
        visual_proj = self.vision_proj(visual_cls_tokens)
        visual_proj = F.normalize(visual_proj, dim=-1)

        # reshape visual embs to (B,T,H,W,D), to conform with expected output.
        # separate out the frame-level cls tokens (index 0 per frame) first.
        frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
        h = H // self.visual_encoder.patch_size
        w = W // self.visual_encoder.patch_size
        visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)

        return VideoEmbedderOutput(
            visual_proj=visual_proj,
            visual_embs=visual_embs,
            visual_query_output=visual_query_output,
            visual_cls_tokens=visual_cls_tokens,
            frame_cls_tokens=frame_cls_tokens,
        )

    def get_text_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextEmbedderOutput:
        """Encode tokenized text into projected and per-token embeddings.

        Args:
            input_ids: token ids of shape `(batch_size, sequence_length)`.
            attention_mask: padding mask of the same shape (1 = keep, 0 = pad).
        """
        text_query_output = self.qformer.bert(
            input_ids=input_ids,
            attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
            return_dict=True,
        )
        # Use the first token's hidden state as the sentence representation,
        # then project and L2-normalize.
        text_proj = text_query_output.last_hidden_state[:, 0, :]
        text_proj = self.text_proj(text_proj)
        text_proj = F.normalize(text_proj, dim=-1)

        return TextEmbedderOutput(
            text_proj=text_proj,
            text_embs=text_query_output.last_hidden_state,
            text_query_output=text_query_output,
        )

    @classmethod
    @rank0_first
    def _init_qformer(
        cls: "CosmosEmbed1",
        num_query_tokens: int,
        encoder_width: int,
        vocab_size: int,
        hidden_size: int = 768,
    ) -> tuple[BertLMHeadModel, nn.Parameter]:
        """Convenience function for initializing QFormer module."""
        qformer = load_qformer(
            num_query_tokens=num_query_tokens,
            encoder_width=encoder_width,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
        )
        # Learned query tokens, normal-initialized (std=0.02, BERT-style).
        query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
        query_tokens.data.normal_(mean=0.0, std=0.02)
        return qformer, query_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load pretrained weights, with special handling for TransformerEngine.

        When ``config.transformer_engine`` is set, checkpoint weights are first
        loaded into a standard (non-TE) model, then copied into a freshly built
        TE model via a non-strict ``load_state_dict``.
        """
        # Get config from kwargs or load from pretrained path
        config = kwargs.get("config", None)
        if config is None:
            config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)

        if config.transformer_engine:
            config_no_te = deepcopy(config)
            config_no_te.transformer_engine = False
            config_no_te.use_fp8 = False  # Also disable FP8 for the base model

            # Remove 'config' from kwargs to avoid conflict, we'll pass config_no_te
            kwargs_no_te = deepcopy(kwargs)
            kwargs_no_te["config"] = config_no_te

            # Load standard (non-TE) model & weights
            base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
            base_state_dict = base_model.state_dict()

            # Now build the TE version of the model
            model_with_te = cls(config=config)

            # Load weights from non-TE model; strict=False because TE modules
            # may name/shape some parameters differently.
            missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)

            # Optional debug log
            if missing:
                print(f"[TransformerEngine] Missing keys: {missing}")
            if unexpected:
                print(f"[TransformerEngine] Unexpected keys: {unexpected}")

            return model_with_te
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
259
+
260
+
261
+ AutoModel.register(CosmosEmbed1Config, CosmosEmbed1)
modeling_outputs.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Output definitions for Cosmos-Embed1."""
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ import torch
22
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, ModelOutput
23
+
24
+
25
@dataclass
class TextEmbedderOutput(ModelOutput):
    """Output of the text embedder branch's `get_text_embeddings` function.

    Attrs:
        text_proj (`torch.FloatTensor` of shape `(batch_size, num_visual_embs, embed_dim)` or `(batch_size, embed_dim)`):
            text (video-aligned) projected embeddings from the text branch.
        text_embs (`torch.FloatTensor` of shape `(batch_size, ...)`):
            text tokens from the text branch.
        text_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Useful text branch intermediate outputs like hidden states, past key values, attentions etc.
    """

    text_proj: Optional[torch.FloatTensor] = None
    text_embs: Optional[torch.FloatTensor] = None
    text_query_output: Optional[CausalLMOutputWithCrossAttentions] = None
41
+
42
+
43
@dataclass
class VideoEmbedderOutput(ModelOutput):
    """Output of the video embedder branch's `get_video_embeddings` function.

    Attrs:
        visual_proj (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
            visual (text-aligned) projected embeddings from the visual branch.
        visual_embs (`torch.FloatTensor` of shape `(batch_size, num_frames, height, width, encoder_dim)`):
            per-frame dense visual embeddings from the visual encoder.
        visual_cls_tokens (`torch.FloatTensor` of shape `(batch_size, qformer_dim)`):
            visual pooled tokens from the visual branch prior to projection and normalization.
        frame_cls_tokens (`torch.FloatTensor` of shape `(batch_size, num_frames, encoder_dim)`):
            per-frame cls tokens from the visual encoder.
        visual_query_output (`transformers.modeling_outputs.CausalLMOutputWithCrossAttentions`):
            Useful visual branch intermediate outputs like hidden states, past key values, attentions etc.
    """

    visual_proj: Optional[torch.FloatTensor] = None
    visual_embs: Optional[torch.FloatTensor] = None
    visual_cls_tokens: Optional[torch.FloatTensor] = None
    frame_cls_tokens: Optional[torch.FloatTensor] = None
    visual_query_output: Optional[CausalLMOutputWithCrossAttentions] = None
65
+
66
+
67
@dataclass
class TextVideoEmbedderOutput(VideoEmbedderOutput, TextEmbedderOutput):
    """Merged class of `VideoEmbedderOutput` and `TextEmbedderOutput`.

    Carries both branches' fields for joint text+video forward passes.
    """
modeling_qformer.py ADDED
@@ -0,0 +1,1060 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Copyright (c) 2023, salesforce.com, inc.
17
+ # All rights reserved.
18
+ # SPDX-License-Identifier: BSD-3-Clause
19
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
20
+
21
+ """Q-Former module.
22
+
23
+ Code adapted from:
24
+ https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip2_models/Qformer.py
25
+ """
26
+
27
+ import math
28
+ from logging import getLogger
29
+ from typing import Literal, Tuple
30
+
31
+ import torch
32
+ import torch.utils.checkpoint
33
+ from torch import Tensor, device, nn
34
+ from torch.nn import CrossEntropyLoss
35
+ from transformers import GenerationMixin
36
+ from transformers.activations import ACT2FN
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPastAndCrossAttentions,
39
+ BaseModelOutputWithPoolingAndCrossAttentions,
40
+ CausalLMOutputWithCrossAttentions,
41
+ )
42
+ from transformers.modeling_utils import (
43
+ PreTrainedModel,
44
+ apply_chunking_to_forward,
45
+ find_pruneable_heads_and_indices,
46
+ prune_linear_layer,
47
+ )
48
+ from transformers.models.bert.configuration_bert import BertConfig
49
+
50
+ logger = getLogger(__file__)
51
+
52
+
53
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Supports a Q-Former-style forward where learned `query_embeds` are prepended
    to (or used instead of) the token embeddings; positions are only applied to
    the text tokens, never to the query tokens.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed token ids and/or prepend query embeddings.

        Args:
            input_ids: optional `(batch, seq)` token ids; may be None for query-only input.
            position_ids: optional explicit positions; defaults to a slice of the buffer
                offset by `past_key_values_length` (incremental decoding).
            query_embeds: optional `(batch, num_queries, hidden)` learned query tokens.
            past_key_values_length: number of already-cached positions to skip.

        Returns:
            LayerNorm'd and dropout'd embeddings; query tokens first when both given.
        """
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            # Query-only input: no text positions are consumed.
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            # Query tokens (if any) are prepended and get no positional encoding.
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
101
+
102
+
103
+ # TODO: add more efficient attention kernels like FlashAttention V2/V3.
104
class BertSelfAttention(nn.Module):
    """Multi-head self- or cross-attention with optional relative-position scores.

    When constructed with `is_cross_attention=True`, keys and values are projected
    from `config.encoder_width`-dim encoder states rather than the hidden states.
    Also supports an incremental KV cache via `past_key_value`.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Cross-attention: keys/values come from the (possibly wider) encoder states.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True, cross-attention maps (and their gradients) are stashed for inspection.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Hook target: records gradients flowing through the attention probabilities.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute attention; returns (context[, attn_probs], present_key_value)."""
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Incremental decoding: append the new K/V to the cache along the seq dim.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Cache is always returned (even for cross-attention) as the last output element.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Standard scaled dot-product attention scaling.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
236
+
237
+
238
class BertSelfOutput(nn.Module):
    """Projects attention context back to hidden size, then residual-add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add the residual, normalize."""
        projected = self.dropout(self.dense(hidden_states))
        # Post-LN transformer convention: normalize after the residual connection.
        return self.LayerNorm(projected + input_tensor)
250
+
251
+
252
class BertAttention(nn.Module):
    """Full attention block: `BertSelfAttention` followed by its output projection."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads, shrinking the q/k/v and output projections."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Slice the pruned head rows out of q/k/v and the matching columns out
        # of the output projection.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keep the reduced head count and remember which heads are gone.
        self.self.num_attention_heads -= len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention then residual projection; passes through extra outputs."""
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(attn_results[0], hidden_states)
        # Keep attention probabilities / cached key-values, if any, after the context.
        return (attention_output,) + attn_results[1:]
303
+
304
+
305
class BertIntermediate(nn.Module):
    """Feed-forward expansion: linear up-projection followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `config.hidden_act` is either a string key into ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        """Up-project and activate."""
        return self.intermediate_act_fn(self.dense(hidden_states))
318
+
319
+
320
class BertOutput(nn.Module):
    """Feed-forward down-projection with dropout, residual add, and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Down-project the FFN output and fuse it with the residual stream."""
        down = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(down + input_tensor)
332
+
333
+
334
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Self-attention runs over [query tokens ; text tokens]. Cross-attention (on
    every `cross_attention_freq`-th layer when enabled) is applied to the query
    tokens only. Query and text tokens then go through separate feed-forward
    stacks (`*_query` vs plain).
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        # Cross-attention is only inserted every `cross_attention_freq` layers.
        if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Separate FFN weights for the query-token stream.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run one layer; `query_length` is the number of leading query tokens."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # Split off the query-token slice; only it attends to the encoder.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # Query tokens use the dedicated query FFN weights.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            # Any remaining (text) tokens use the standard FFN, then re-concatenate.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Standard FFN path (text tokens, or everything when query_length == 0).
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Query-token FFN path with its own weights.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
432
+
433
+
434
class BertEncoder(nn.Module):
    """Stack of `BertLayer`s with optional gradient checkpointing and KV caching."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers; returns a tuple or `BaseModelOutputWithPastAndCrossAttentions`."""
        # Accumulators are only allocated when the caller asks for them.
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Close over the non-tensor args that checkpoint() cannot forward.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Present key/value is the last element of every layer's output tuple.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # NOTE(review): layer_outputs[2] is only a cross-attention map on
                # layers that actually have cross-attention; on other layers this
                # index holds a different entry — confirm against BertLayer.forward.
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
530
+
531
+
532
class BertPooler(nn.Module):
    """Pools a sequence by transforming its first-token hidden state (dense + tanh)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return the pooled representation of the batch.

        "Pooling" here is simply the hidden state of the first token passed
        through a dense layer and tanh.
        """
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
545
+
546
+
547
class BertPredictionHeadTransform(nn.Module):
    """Dense + activation + LayerNorm transform applied before the LM decoder head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `config.hidden_act` is either a string key into ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Apply dense projection, activation, and layer normalization in order."""
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
562
+
563
+
564
class BertLMPredictionHead(nn.Module):
    """Maps hidden states to vocabulary logits: transform -> decoder (+ per-token bias)."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder weights may be tied to the input embeddings elsewhere, so the
        # per-token output bias is kept as a separate parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is resized together with the decoder by
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Transform then project to vocabulary-sized logits."""
        return self.decoder(self.transform(hidden_states))
582
+
583
+
584
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing the MLM prediction head under the `predictions` name."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return vocabulary logits for each position of `sequence_output`."""
        return self.predictions(sequence_output)
592
+
593
+
594
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    # `position_ids` is a registered buffer, not a weight; missing it is expected.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Modified from original LAVIS implementation to add truncated normal.
            # This matches the original Tensorflow implementation from Google.
            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear biases are zeroed after the weight init above (Embedding has no bias).
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
615
+
616
+
617
+ class BertModel(BertPreTrainedModel):
618
+ """
619
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
620
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
621
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
622
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
623
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
624
+ input to the forward pass.
625
+ """
626
+
627
    def __init__(self, config, add_pooling_layer=False):
        """Build embeddings, encoder, and (optionally) the pooled-output head."""
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        # Pooled ([CLS]) output head is optional and disabled by default.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()
638
+
639
    def get_input_embeddings(self):
        """Return the word-embedding table (used for weight tying/resizing)."""
        return self.embeddings.word_embeddings
641
+
642
    def set_input_embeddings(self, value):
        """Replace the word-embedding table (used for weight tying/resizing)."""
        self.embeddings.word_embeddings = value
644
+
645
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
652
+
653
    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                Whether to build a causal (uni-directional) mask in addition to the padding mask.
            has_query (:obj:`bool`):
                Whether a query-token prefix precedes the text tokens (UniLM-style masking).

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # Lower-triangular boolean mask: position i may attend to j <= i.
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                # The attention mask may cover a prefix (e.g. query tokens or cached
                # positions) longer than the causal block; extend the causal mask.
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Query rows cannot attend to text tokens (zeros block).
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    # All rows may attend to the prefix columns (ones block).
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
736
+
737
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        Run the Q-Former BERT trunk over (optional) query embeddings plus (optional) text tokens.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        # At least one of text ids / query embeddings must be provided.
        if input_ids is None:
            assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"

        # past_key_values_length: the cached key length minus the query prefix, i.e. the
        # number of already-decoded text tokens.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        # Embedding layer concatenates query embeddings (if any) in front of the token embeddings.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device, dtype=embedding_output.dtype
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # NOTE(review): this branch reads input_ids.shape, so it assumes input_ids is not
            # None whenever is_decoder=True — confirm against callers.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of hidden states means per-layer cross-attention inputs; masks follow suit.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device, dtype=attention_mask.dtype)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
875
+
876
+
877
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
    """BERT with a causal language-modeling head on top, used as the Q-Former decoder.

    Wraps a `BertModel` trunk (no pooler) plus a `BertOnlyMLMHead` and adds next-token
    loss computation and generation hooks.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # The decoder linear layer of the MLM head maps hidden states to vocab logits.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Training with labels and caching are mutually exclusive.
        if labels is not None:
            use_cache = False
        # With a cache present, query tokens are already encoded in past_key_values.
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        # Strip the query-token prefix: the LM head only scores text positions.
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            # Logits for next-token prediction exclude the last position (nothing follows it).
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Per-sample loss: sum token losses within each sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # Prepend an all-ones mask for the query-token prefix.
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        # Reorder cached key/value states to follow beam-search hypothesis reordering.
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
1029
+
1030
+
1031
def load_qformer(
    num_query_tokens: int,
    encoder_width: int,
    vocab_size: int = 30523,
    hidden_size: int = 768,
    cross_attention_freq: int = 2,
    base_model: Literal["bert-base-uncased", "bert-large-uncased"] = "bert-base-uncased",
) -> BertLMHeadModel:
    """Utility to load QFormer module.

    Args:
        num_query_tokens (int): number of query tokens.
        encoder_width (int): vector length of visual encoder embeddings.
        vocab_size (int): size of the text vocabulary. Default 30523 (matches the
            checkpoint config; presumably BERT's 30522 plus one extra token — confirm).
        hidden_size (int): vector length of BERT's attention blocks.
        cross_attention_freq (int): block frequency of visual cross-attention.
        base_model (str): Base text model for QFormer. Default `bert-base-uncased`.

    Returns:
        `BertLMHeadModel` module.
    """
    # Start from the pretrained BERT configuration and graft Q-Former-specific fields on.
    encoder_config = BertConfig.from_pretrained(base_model)
    encoder_config.encoder_width = encoder_width
    encoder_config.hidden_size = hidden_size
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_tokens
    encoder_config.vocab_size = vocab_size
    return BertLMHeadModel(encoder_config)
modeling_utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Misc functions and modules for Cosmos-Embed1."""
17
+
18
+ import functools
19
+ from logging import getLogger
20
+ from typing import Callable, Optional, Protocol
21
+
22
+ import torch
23
+ import torch.distributed as dist
24
+ import torch.nn as nn
25
+
26
+ logger = getLogger(__file__)
27
+
28
+
29
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker; 0 when torch.distributed is
        unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(group)
    return 0
39
+
40
+
41
def barrier() -> None:
    """Barrier for all GPUs; no-op when torch.distributed is not in use."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
45
+
46
+
47
def rank0_first(func: Callable) -> Callable:
    """Decorator: execute `func` on rank 0 first, then on all other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        is_rank0 = get_rank() == 0
        if is_rank0:
            result = func(*args, **kwargs)
        # All ranks wait until rank 0 has finished (e.g. downloading a checkpoint).
        barrier()
        if not is_rank0:
            result = func(*args, **kwargs)
        return result

    return wrapper
60
+
61
+
62
def add_docstring(docstring: str):
    """Decorator factory that installs `docstring` as the wrapped function's __doc__."""

    def decorator(func):
        func.__doc__ = docstring
        return func

    return decorator
68
+
69
+
70
+ INIT_DOCSTRING = """
71
+ Constructor for encoding module.
72
+
73
+ Args:
74
+ embed_dim: size of embedding vectors, e.g. x.shape[3].
75
+ max_len: maximum length of temporal sequence, e.g. x.shape[1].
76
+ """
77
+
78
+ FORWARD_DOCSTRING = """
79
+ Forward function.
80
+
81
+ Args:
82
+ x (`torch.Tensor`): rank 4 tensor to add spatio-temporal encodings to.
83
+
84
+ Returns:
85
+ `torch.Tensor` of rank 4.
86
+ """
87
+
88
+
89
class EncodingProtocol(Protocol):
    """Structural interface for spatio-temporal encoding modules."""

    def __init__(self, embed_dim: int, max_len: int) -> None:
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass
95
+
96
+
97
def interpolate_temp_pos_embed(temp_embed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Linearly interpolates temporal encoding from `temp_embed.shape[0] to num_frames."""
    # (T, D) -> (1, D, T): F.interpolate treats the last axis as the resample axis.
    resized = temp_embed.permute(1, 0).unsqueeze(0)
    resized = nn.functional.interpolate(
        resized,
        size=num_frames,
        mode="linear",
        align_corners=False,
    )
    # (1, D, T') -> (T', D)
    return resized.squeeze(0).permute(1, 0)
108
+
109
+
110
class TemporalParameterEncoding(nn.Module, EncodingProtocol):
    """Learnable per-frame temporal encoding added to the query-token tensor."""

    @add_docstring(INIT_DOCSTRING)
    def __init__(self, embed_dim: int, max_len: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        # One learnable vector per frame, truncated-normal initialized.
        self.temp_embed = nn.Parameter(torch.zeros(self.max_len, self.embed_dim))
        nn.init.trunc_normal_(self.temp_embed, std=0.02)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_frames = x.shape[1]
        embed = self.temp_embed
        if num_frames != embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {num_frames}.")
            embed = interpolate_temp_pos_embed(embed, num_frames)
        # (T, D) -> (1, T, 1, D) so it broadcasts over batch and query axes.
        return x + embed.unsqueeze(0).unsqueeze(2)
129
+
130
+
131
def create_neighbor_weight_matrix(num_tokens: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return W where W[i, j] = 2^{-|i - j|}: exponentially decaying neighbor weights."""
    idx = torch.arange(num_tokens, dtype=dtype, device=device)
    distance = (idx[None, :] - idx[:, None]).abs()
    return torch.pow(2.0, -distance)
136
+
137
+
138
def compute_t_adj(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Mix the token axis of `x` (b, f, n, d) with `weights` (n, k) -> (b, f, k, d)."""
    # Equivalent to einsum("bfnd,nk->bfkd"): matmul over the token axis.
    return (x.transpose(2, 3) @ weights).transpose(2, 3)
140
+
141
+
142
def token_propagation(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Apply neighboring token propagation update.

    Straight-through trick: `mixed - mixed.detach()` is numerically zero, so the
    forward value equals `x`, but gradients also flow through the neighbor mixing.
    """
    weights = create_neighbor_weight_matrix(num_tokens, x.device, x.dtype)
    mixed = compute_t_adj(x, weights)
    return x + (mixed - mixed.detach())
147
+
148
+
149
class NeighboringTokenPropagationEncoding(TemporalParameterEncoding):
    """
    Neighboring Token Propagation method inspired by Momentor (https://arxiv.org/abs/2402.11435)
    """

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, num_frames, num_queries, _ = x.shape
        embed = self.temp_embed
        if num_frames != embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {num_frames}.")
            embed = interpolate_temp_pos_embed(embed, num_frames)
        embed = embed.unsqueeze(0).unsqueeze(2)

        # Only during training: straight-through neighbor propagation over query tokens.
        if self.training:
            embed = token_propagation(embed, num_queries)
        return x + embed
167
+
168
+
169
class EncodingFactory(nn.Module):
    """Instantiates the configured temporal encoding module and delegates to it."""

    def __init__(self, encoding_type: str, embed_dim: int, max_len: int) -> None:
        super().__init__()
        registry = {
            "temporal_parameter": TemporalParameterEncoding,
            "neighboring_token_propagation": NeighboringTokenPropagationEncoding,
        }
        # Unknown encoding_type raises KeyError, matching the original lookup.
        self.encoding = registry[encoding_type](embed_dim=embed_dim, max_len=max_len)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoding(x)
modeling_vit.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Copyright (c) 2023, salesforce.com, inc.
17
+ # All rights reserved.
18
+ # SPDX-License-Identifier: BSD-3-Clause
19
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
20
+
21
+ """
22
+ EVA-CLIP backbone used in BLIP2.
23
+
24
+ Code adapted from:
25
+ https://github.com/salesforce/LAVIS/blob/main/lavis/models/eva_vit.py
26
+ """
27
+
28
+
29
+ import math
30
+ from functools import partial
31
+ from logging import getLogger
32
+ from typing import Any, Optional, Tuple, Union
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint as checkpoint
38
+
39
+ logger = getLogger(__file__)
40
+
41
+ TRANSFORMER_ENGINE_AVAILABLE = False
42
+ try:
43
+ import transformer_engine.pytorch as te
44
+ from transformer_engine.common.recipe import DelayedScaling, Format
45
+
46
+ TRANSFORMER_ENGINE_AVAILABLE = True
47
+ logger.info("Transformer Engine is available, can set `transformer_engine=True` in config " "for faster inference.")
48
+ except ImportError:
49
+ pass
50
+
51
+
52
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    """
    # Identity at inference time or when dropping is disabled.
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
65
+
66
+
67
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegates to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
79
+
80
+
81
class Mlp(nn.Module):
    """Transformer feed-forward block: linear -> activation -> linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        drop: float = 0.0,
        transformer_engine: bool = False,
    ) -> None:
        super().__init__()
        # Missing sizes default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        linear_cls = te.Linear if transformer_engine else nn.Linear
        self.fc1 = linear_cls(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # NOTE: dropout is applied only after the second projection.
        return self.drop(self.fc2(self.act(self.fc1(x))))
106
+
107
+
108
class Attention(nn.Module):
    """Multi-head self-attention with optional windowed relative position bias (EVA/BEiT style)."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # Default scale is 1/sqrt(head_dim) unless overridden by qk_scale.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)

        if window_size:
            self.window_size = window_size
            # (2*Wh-1)*(2*Ww-1) in-window offsets plus 3 special entries for the cls token.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            # Index 0 along each axis is reserved for the cls token.
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        # x: (B, N, C); rel_pos_bias (optional): additive bias broadcastable to (B, nH, N, N).
        B, N, C = x.shape
        qkv = self.qkv(x)
        # (B, N, 3*H*Dh) -> (3, B, H, N, Dh)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            # Gather the per-pair bias from the table using the precomputed index buffer.
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
            )  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, H, N, Dh) -> (B, N, H*Dh), then output projection.
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
193
+
194
+
195
class TransformerEngineAttention(nn.Module):
    """Multi-head self-attention using Transformer Engine's fused DotProductAttention.

    Drop-in alternative to `Attention` when TE is available; `window_size` and
    `rel_pos_bias` are not supported and raise NotImplementedError.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        window_size: Optional[int] = None,
        attn_head_dim: Optional[int] = None,
        checkpoint_attention: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.checkpoint_attention = checkpoint_attention
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # Default scale is 1/sqrt(head_dim) unless overridden by qk_scale.
        self.scale = qk_scale or head_dim**-0.5

        # QKV projection
        self.qkv = te.Linear(dim, all_head_dim * 3, bias=qkv_bias)

        if window_size:
            raise NotImplementedError("`window_size` not implemented for TE!")

        self.te_attn = te.DotProductAttention(
            num_attention_heads=num_heads,
            kv_channels=head_dim,
            attention_dropout=attn_drop,
            qkv_format="bshd",
            softmax_scale=self.scale,
            attn_mask_type="no_mask",
        )

        # output projection + dropout
        self.proj = te.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: [B, N, C]
        rel_pos_bias (optional): tensor of shape [num_heads, N, N]
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        # (B, N, 3*H*Dh) -> (3, B, N, H, Dh); TE expects the "bshd" layout.
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # BNHC format

        if rel_pos_bias is not None:
            raise NotImplementedError("`rel_pos_bias` not implemented for TE!")

        # run TE's fused attention
        y = self.te_attn(q, k, v, checkpoint_core_attention=self.checkpoint_attention)

        # final proj + dropout
        return self.proj_drop(self.proj(y))
254
+
255
+
256
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped
    in a residual connection with optional LayerScale (``init_values``) and
    stochastic depth (``drop_path``).
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        transformer_engine: bool = False,
        checkpoint_attention: bool = False,
    ):
        super().__init__()
        self.transformer_engine = transformer_engine
        self.window_size = window_size
        self.checkpoint_attention = checkpoint_attention

        # Attention checkpointing is only wired through the TE attention path.
        if checkpoint_attention and not transformer_engine:
            raise ValueError("`checkpoint_attention` needs `transformer_engine`!")

        self.norm1 = norm_layer(dim)
        attn_cls = TransformerEngineAttention if transformer_engine else Attention
        self.attn = attn_cls(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
            checkpoint_attention=checkpoint_attention,
        )

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            transformer_engine=transformer_engine,
        )

        # LayerScale: learnable per-channel scaling of each residual branch.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        else:
            self.gamma_1 = None
            self.gamma_2 = None

    def forward(self, x, rel_pos_bias=None):
        """Apply attention then MLP, each as a (possibly scaled) residual update."""
        attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        return x + self.drop_path(mlp_out)
323
+
324
+
325
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a non-overlapping strided convolution."""

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_h * grid_w

        # Patchify + embed in one op: kernel size == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        """Embed an image batch: (B, C, H, W) -> (B, num_patches, embed_dim)."""
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, gh, gw) -> (B, gh*gw, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
353
+
354
+
355
class RelativePositionBias(nn.Module):
    """Shared relative position bias for window attention (BEiT-style).

    Produces a per-head additive attention bias of shape
    ``(num_heads, Wh*Ww + 1, Wh*Ww + 1)`` covering all patch tokens plus the
    [CLS] token. The last three table rows encode the cls->token, token->cls
    and cls->cls interactions respectively.

    Args:
        window_size: patch grid size ``(Wh, Ww)``.
        num_heads: number of attention heads (one bias column per head).
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # (2*Wh-1) * (2*Ww-1) distinct 2-D offsets, plus 3 cls-related slots.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # `indexing="ij"` matches the historical default and silences the
        # torch.meshgrid deprecation warning; the produced values are unchanged.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3  # cls -> token
        relative_position_index[0:, 0] = self.num_relative_distance - 2  # token -> cls
        relative_position_index[0, 0] = self.num_relative_distance - 1  # cls -> cls

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """Return the bias tensor of shape ``(num_heads, Wh*Ww+1, Wh*Ww+1)``."""
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
390
+
391
+
392
class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    Produces a sequence of token features ([CLS] + patch tokens); there is no
    pooling or classification head applied in ``forward``. Optionally runs the
    blocks through TransformerEngine with FP8 autocast.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=None,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        use_mean_pooling=True,
        init_scale=0.001,
        checkpoint_activations: bool = False,
        checkpoint_attention: bool = False,
        transformer_engine: bool = False,
        use_fp8: bool = False,
    ):
        # NOTE(review): `use_mean_pooling` and `init_scale` are accepted but never
        # read in this implementation — presumably kept for config/checkpoint
        # compatibility; confirm before removing.
        super().__init__()
        self.image_size = img_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.transformer_engine = transformer_engine
        self.use_fp8 = use_fp8
        self.fp8_recipe = None

        if use_fp8 and not transformer_engine:
            raise ValueError("`transformer_engine` must be enabled for `use_fp8`.")
        if use_fp8:
            # FP8 Recipe: Hybrid E4M3 forward, E5M2 backward
            self.fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and (optional) absolute positional embedding
        # covering [CLS] + all patch tokens.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        # One bias module shared by all blocks (BEiT-style), or none.
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_attention = checkpoint_attention

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                    transformer_engine=transformer_engine,
                    checkpoint_attention=self.checkpoint_attention,
                )
                for i in range(depth)
            ]
        )

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Rescale residual-branch projections by 1/sqrt(2*layer_id) so deeper
        layers start with smaller residual contributions."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Default init: truncated-normal linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        # NOTE(review): `self.head` is never created in __init__; this raises
        # AttributeError unless `reset_classifier` was called first — confirm
        # whether a default head should be initialized.
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the classification head (Identity when num_classes == 0)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run the backbone, wrapping it in FP8 autocast when configured."""
        if self.transformer_engine and self.use_fp8:
            with te.fp8_autocast(enabled=True, fp8_recipe=self.fp8_recipe):
                return self._forward_uncast(x)
        return self._forward_uncast(x)

    def _forward_uncast(self, x):
        """Core forward pass: patchify, prepend [CLS], add pos-embed, run blocks."""
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.checkpoint_activations:
                # Trade compute for memory by re-running the block in backward.
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x

    def forward(self, x):
        """Return token features of shape (B, 1 + num_patches, embed_dim)."""
        x = self.forward_features(x)
        return x

    def get_intermediate_layers(self, x):
        """Return the output of every block as a list (no activation checkpointing)."""
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)

        return features

    def get_num_layer(self, var_name=""):
        """Map a parameter name to a layer index for layer-wise LR decay."""
        if var_name in ("cls_token", "mask_token", "pos_embed"):
            return 0
        elif var_name.startswith("patch_embed"):
            return 0
        elif var_name.startswith("rel_pos_bias"):
            return len(self.blocks) - 1
        elif var_name.startswith("blocks"):
            layer_id = int(var_name.split(".")[1])
            return layer_id + 1
        else:
            return len(self.blocks)
566
+
567
+
568
def interpolate_pos_embed(
    pos_embed_key: str,
    num_patches: int,
    patch_embed_shape: int,
    checkpoint_model: dict[str, torch.Tensor],
    target_h: Optional[int] = None,
    target_w: Optional[int] = None,
) -> None:
    """Resize a checkpoint's absolute positional embedding in place.

    If ``pos_embed_key`` is present in ``checkpoint_model`` and the stored grid
    size differs from the target, the patch-position tokens are bicubically
    interpolated to the new grid; the extra (e.g. [CLS]) tokens are kept
    unchanged. No-op when the key is missing or sizes already match.

    Args:
        pos_embed_key: key of the positional embedding in ``checkpoint_model``.
        num_patches: number of patch tokens in the target model.
        patch_embed_shape: total sequence length of the target embedding
            (num_patches + extra tokens), i.e. ``pos_embed.shape[-2]``.
        checkpoint_model: state dict to be modified in place.
        target_h: target grid height; when both ``target_h`` and ``target_w``
            are given, a non-square grid is used.
        target_w: target grid width.
    """
    if pos_embed_key in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model[pos_embed_key].float()
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_extra_tokens = patch_embed_shape - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)

        # If target dimensions are provided, use them; otherwise assume square
        if target_h is not None and target_w is not None:
            new_h, new_w = target_h, target_w
        else:
            # height (== width) for the new position embedding (square assumption)
            new_size = int(num_patches**0.5)
            new_h, new_w = new_size, new_size

        # class_token and dist_token are kept unchanged
        if orig_size * orig_size != new_h * new_w:
            logger.info("Positional interpolation from %dx%d to %dx%d" % (orig_size, orig_size, new_h, new_w))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_h, new_w), mode="bicubic", align_corners=False
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model[pos_embed_key] = new_pos_embed
604
+
605
+
606
class PositionalEmbeddingHook:
    """`load_state_dict` pre-hook that resizes a checkpoint's positional
    embedding (via ``interpolate_pos_embed``) to match the model's patch grid
    before the weights are copied in.
    """

    def __init__(self, pos_embed_name, num_patches, patch_embed_shape, target_h=None, target_w=None):
        self.pos_embed_name = pos_embed_name
        self.num_patches = num_patches
        self.patch_embed_shape = patch_embed_shape
        self.target_h = target_h
        self.target_w = target_w

    def __call__(self, state_dict, prefix, *args, **kwargs) -> None:
        """Invoked by torch before loading; mutates `state_dict` in place."""
        logger.info("Calling `PositionalEmbeddingHook`")
        key = f"{prefix}{self.pos_embed_name}"
        interpolate_pos_embed(
            key,
            self.num_patches,
            self.patch_embed_shape,
            state_dict,
            self.target_h,
            self.target_w,
        )
620
+
621
+
622
class EvaViTG(VisionTransformer):
    """EVA ViT-g/14 visual encoder: a fixed-hyperparameter VisionTransformer
    (patch 14, width 1408, depth 39) with checkpoint loading and positional
    embedding interpolation support.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        drop_path_rate: float = 0.4,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None,
        checkpoint_activations: bool = False,
        checkpoint_attention: bool = False,
        transformer_engine: bool = False,
        use_fp8: bool = False,
        **kwargs: Any,
    ) -> None:
        # Fail fast when TE features are requested but unavailable/inconsistent.
        if not TRANSFORMER_ENGINE_AVAILABLE and transformer_engine:
            raise ValueError(
                "TransformerEngine is not available, "
                "please install transformer-engine or set `transformer_engine=False` in config."
            )
        if use_fp8 and not transformer_engine:
            raise ValueError("`transformer_engine` must be enabled for `use_fp8`.")
        # Fixed EVA ViT-g/14 architecture; only size/regularization are configurable.
        super().__init__(
            img_size=img_size,
            patch_size=14,
            use_mean_pooling=False,
            embed_dim=1408,
            depth=39,
            num_heads=1408 // 88,
            mlp_ratio=4.3637,
            qkv_bias=True,
            drop_path_rate=drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            checkpoint_activations=checkpoint_activations,
            checkpoint_attention=checkpoint_attention,
            transformer_engine=transformer_engine,
            use_fp8=use_fp8,
        )
        self.checkpoint_path = checkpoint_path

        # compatibility with pre-trained checkpoints
        self.register_pre_hooks()

        # load pre-trained checkpoints
        if pretrained:
            self.load_checkpoint()

    def load_checkpoint(self) -> None:
        """Load weights from ``self.checkpoint_path`` (non-strict).

        NOTE(review): ``torch.load`` without ``weights_only=True`` unpickles
        arbitrary objects — unsafe if the checkpoint path is untrusted; confirm
        whether ``weights_only=True`` can be enabled here.
        """
        logger.info(f"Loading checkpoint from {self.checkpoint_path}")
        state_dict = torch.load(self.checkpoint_path, map_location="cpu")
        incompatible_keys = self.load_state_dict(state_dict, strict=False)
        logger.info(f"Incompatible keys: {incompatible_keys}")
        logger.info(f"Loaded visual encoder {type(self)} with state dict from {self.checkpoint_path}")

    def register_pre_hooks(self) -> None:
        """Register positional embedding interpolation when loading pre-trained checkpoints using different resolution."""
        # Calculate target patch dimensions for non-square support
        patch_h = self.patch_embed.patch_shape[0]
        patch_w = self.patch_embed.patch_shape[1]

        embed_hook = PositionalEmbeddingHook(
            pos_embed_name="pos_embed",
            num_patches=self.patch_embed.num_patches,
            patch_embed_shape=self.pos_embed.shape[-2],
            target_h=patch_h,
            target_w=patch_w,
        )
        self._register_load_state_dict_pre_hook(embed_hook)

    def _initialize_weights(self, m):
        # NOTE(review): this duplicates VisionTransformer._init_weights and is not
        # referenced anywhere in the visible code — confirm whether it is dead.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
preprocessing_embed1.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Processor class for Cosmos-Embed1
18
+ """
19
+
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torchvision
25
+ from transformers import AutoProcessor, BatchFeature
26
+ from transformers.processing_utils import ProcessorMixin
27
+ from transformers.utils import TensorType
28
+
29
+ from .configuration_embed1 import CosmosEmbed1Config
30
+
31
+
32
class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Constructs a processor which wraps a BertTokenizer tokenizer and a fast video resize function.

    Args:
        tokenizer ([`BertTokenizerFast`], *optional*):
            The tokenizer is a required input for text processing.
        config ([`CosmosEmbed1Config`], *optional*):
            Needed for processing options.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 448,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        """Store default preprocessing options; each can be overridden per call.

        Args:
            tokenizer: BERT tokenizer used for text inputs.
            resolution: default target frame size (side, or (H, W)).
            num_video_frames: expected number of frames per clip.
            max_txt_len: default tokenization padding/truncation length.
        """
        super().__init__(tokenizer, **kwargs)
        self.resolution = resolution
        self.num_video_frames = num_video_frames
        self.max_txt_len = max_txt_len

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Optional[Union[int, Tuple[int, int]]] = None,
        num_video_frames: Optional[int] = None,
        max_txt_len: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize text and/or validate + resize + normalize videos.

        Args:
            text: one string or a list of strings to tokenize.
            videos: uint8/float batch of shape (B, T, C, H, W) with values in [0, 255].
            return_tensors: output tensor format for the BatchFeature.
            resolution / num_video_frames / max_txt_len: per-call overrides of
                the defaults stored on the processor.

        Returns:
            BatchFeature with `input_ids`/`attention_mask` (for text) and/or
            `videos` scaled to [0, 1] (for videos).

        Raises:
            ValueError: on malformed video input, or when neither `text` nor
                `videos` is provided.
        """
        inputs = {}

        if text is not None:
            max_txt_len = max_txt_len if max_txt_len is not None else self.max_txt_len
            # Fixed-length padding so batches are uniformly shaped.
            tokenized = self.tokenizer(
                text, return_tensors="pt", padding="max_length", truncation=True, max_length=max_txt_len, **kwargs
            )
            inputs["input_ids"] = tokenized.input_ids
            # Mask is cast to float — presumably the model consumes a float
            # mask directly; confirm against the modeling code.
            inputs["attention_mask"] = tokenized.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            if not isinstance(videos, torch.Tensor) or videos.ndim != 5:
                raise ValueError("Processor expects a numpy or torch tensor of shape BTCHW from [0-255].")
            resolution = resolution if resolution is not None else self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)
            _, t, c, h, w = videos.shape
            if c != 3:
                raise ValueError(f"Expected tensor of shape BTCHW with RGB channels, got channel size {c}.")
            num_video_frames = num_video_frames if num_video_frames is not None else self.num_video_frames
            if t != num_video_frames:
                raise ValueError(f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {t}.")
            # Resize only when the spatial size differs from the target.
            if h != resolution[0] or w != resolution[1]:
                videos = resize_video(videos, resolution)
            if videos.dtype == torch.uint8:
                videos = videos.float()
            # Scale from [0, 255] to [0, 1].
            inputs["videos"] = videos / 255.0

        if not inputs:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")

        return BatchFeature(inputs, tensor_type=return_tensors)
105
+
106
+
107
def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    Args:
        video (torch.Tensor): (B, T, C, H, W) uint8 or float32.
        size (tuple): target (H', W') size.
    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W')
    """
    target_h, target_w = size
    b, t, c, h, w = video.shape
    # Fold time into the batch axis so one 2-D resize handles every frame.
    frames = video.view(b * t, c, h, w)
    transform = torchvision.transforms.Resize(
        (target_h, target_w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    frames = transform(frames)
    out_h, out_w = frames.shape[-2:]
    return frames.view(b, t, c, out_h, out_w)
128
+
129
+
130
# Make the processor discoverable via AutoProcessor for this config type.
AutoProcessor.register(CosmosEmbed1Config, CosmosEmbed1Processor)


__all__ = ["CosmosEmbed1Processor"]
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "model_max_length": 512}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff