Ligeng-Zhu commited on
Commit
d8c0285
·
verified ·
1 Parent(s): 08812b1

Upload files with `vila-upload`.

Browse files

Upload conversation.py
Upload media_encoder.py
Upload media.py
Upload utils.py
Upload modeling_vila.py
Upload main.py
Upload constants.py
Upload config.json
Upload README.md
Upload configuration_vila.py
Upload builder.py
Upload base_projector.py
Upload trainer_state.json
Upload mm_utils.py
Upload tokenizer_utils.py
Upload siglip_encoder.py
Upload llm/added_tokens.json
Upload llm/generation_config.json
Upload llm/model-00002-of-00004.safetensors
Upload llm/model-00004-of-00004.safetensors
Upload llm/model-00001-of-00004.safetensors
Upload llm/merges.txt
Upload llm/special_tokens_map.json
Upload llm/config.json
Upload llm/vocab.json
Upload llm/tokenizer_config.json
Upload llm/model-00003-of-00004.safetensors
Upload llm/model.safetensors.index.json
Upload mm_projector/config.json
Upload mm_projector/model.safetensors
Upload vision_tower/config.json
Upload vision_tower/preprocessor_config.json
Upload vision_tower/model.safetensors

README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ library_name: transformers
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - NVILA
7
+ - VLM
8
+ ---
9
+
10
+ # VILA Model Card
11
+
12
+ ## Model details
13
+
14
+ **Model type:**
15
+ NVILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling multi-image VLM. Visual language models (VLMs) have made significant advances in accuracy in recent years. However, their efficiency has received much less attention. This paper introduces NVILA, a family of open VLMs designed to optimize both efficiency and accuracy. Building on top of VILA, we improve its model architecture by first scaling up the spatial and temporal resolutions, and then compressing visual tokens. This "scale-then-compress" approach enables NVILA to efficiently process high-resolution images and long videos. We also conduct a systematic investigation to enhance the efficiency of NVILA throughout its entire lifecycle, from training and fine-tuning to deployment. NVILA matches or surpasses the accuracy of many leading open and proprietary VLMs across a wide range of image and video benchmarks. At the same time, it reduces training costs by 4.5X, fine-tuning memory usage by 3.4X, pre-filling latency by 1.6-2.2X, and decoding latency by 1.2-2.8X. We will soon make our code and models available to facilitate reproducibility.
16
+
17
+ **Model date:**
18
+ NVILA was trained in Nov 2024.
19
+
20
+ **Paper or resources for more information:**
21
+ https://github.com/NVLabs/VILA
22
+
23
+ ```
24
+ @misc{liu2024nvila,
25
+ title={NVILA: Efficient Frontier Visual Language Models},
26
+ author={Zhijian Liu and Ligeng Zhu and Baifeng Shi and Zhuoyang Zhang and Yuming Lou and Shang Yang and Haocheng Xi and Shiyi Cao and Yuxian Gu and Dacheng Li and Xiuyu Li and Yunhao Fang and Yukang Chen and Cheng-Yu Hsieh and De-An Huang and An-Chieh Cheng and Vishwesh Nath and Jinyi Hu and Sifei Liu and Ranjay Krishna and Daguang Xu and Xiaolong Wang and Pavlo Molchanov and Jan Kautz and Hongxu Yin and Song Han and Yao Lu},
27
+ year={2024},
28
+ eprint={2412.04468},
29
+ archivePrefix={arXiv},
30
+ primaryClass={cs.CV},
31
+ url={https://arxiv.org/abs/2412.04468},
32
+ }
33
+ ```
34
+
35
+ ## License
36
+ - The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
37
+ - The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
38
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
39
+ - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
40
+ - [Dataset Licenses](https://github.com/Efficient-Large-Model/VILA/blob/main/data_prepare/LICENSE) for each one used during training.
41
+
42
+ **Where to send questions or comments about the model:**
43
+ https://github.com/NVLabs/VILA/issues
44
+
45
+ ## Intended use
46
+ **Primary intended uses:**
47
+ The primary use of VILA is research on large multimodal models and chatbots.
48
+
49
+ **Primary intended users:**
50
+ The primary intended users of the model are researchers and hobbyists in computer vision, natural language processing, machine learning, and artificial intelligence.
51
+
52
+ ## Input:
53
+ **Input Type:** Image, Video, Text
54
+ **Input Format:** Red, Green, Blue; MP4 ;String
55
+ **Input Parameters:** 2D, 3D
56
+
57
+ ## Output:
58
+ **Output Type:** Text
59
+ **Output Format:** String
60
+
61
+ **Supported Hardware Microarchitecture Compatibility:**
62
+ * Ampere
63
+ * Jetson
64
+ * Hopper
65
+ * Lovelace
66
+
67
+ **[Preferred/Supported] Operating System(s):** <br>
68
+ Linux
69
+
70
+ ## Training dataset
71
+ See [Dataset Preparation](https://arxiv.org/abs/2412.04468) for more details.
72
+
73
+ ** Data Collection Method by dataset
74
+ * [Hybrid: Automated, Human]
75
+
76
+ ** Labeling Method by dataset
77
+ * [Hybrid: Automated, Human]
78
+
79
+ ## Inference:
80
+ **Engine:** [Tensor(RT), Triton, Or List Other Here]
81
+ * PyTorch
82
+ * TensorRT-LLM
83
+ * TinyChat
84
+
85
+ **Test Hardware:**
86
+ * A100
87
+ * Jetson Orin
88
+ * RTX 4090
89
+
90
+ ## Ethical Considerations
91
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
base_projector.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import re
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
22
+
23
+
24
+ class IdentityMap(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, x, *args, **kwargs):
29
+ return x
30
+
31
+ @property
32
+ def config(self):
33
+ return {"mm_projector_type": "identity"}
34
+
35
+
36
+ class SimpleResBlock(nn.Module):
37
+ def __init__(self, channels):
38
+ super().__init__()
39
+ self.pre_norm = nn.LayerNorm(channels)
40
+
41
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
42
+
43
+ def forward(self, x):
44
+ x = self.pre_norm(x)
45
+ return x + self.proj(x)
46
+
47
+
48
+ class DownSampleBlock(nn.Module):
49
+ def forward(self, x):
50
+ vit_embeds = x
51
+ h = w = int(vit_embeds.shape[1] ** 0.5)
52
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
53
+ vit_embeds = self.flat_square(vit_embeds)
54
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
55
+ return vit_embeds
56
+
57
+ def flat_square(self, x):
58
+ n, w, h, c = x.size()
59
+ if w % 2 == 1:
60
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
61
+ n, w, h, c = x.size()
62
+ if h % 2 == 1:
63
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
64
+ n, w, h, c = x.size()
65
+ x = x.contiguous()
66
+ x = x.view(n, w, int(h / 2), int(c * 2))
67
+ x = x.permute(0, 2, 1, 3).contiguous()
68
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
69
+ x = x.permute(0, 2, 1, 3).contiguous()
70
+ return x
71
+
72
+
73
+ class DownSample2x2BlockFix(nn.Module):
74
+ def forward(self, x):
75
+ vit_embeds = x
76
+ h = w = int(vit_embeds.shape[1] ** 0.5)
77
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
78
+ vit_embeds = flat_square_2x2(vit_embeds)
79
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
80
+ return vit_embeds
81
+
82
+
83
+ def flat_square_2x2(x):
84
+ n, w, h, c = x.size()
85
+ if w % 2 == 1:
86
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
87
+ n, w, h, c = x.size()
88
+ x = x.contiguous()
89
+ if h % 2 == 1:
90
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
91
+ n, w, h, c = x.size()
92
+ x = x.view(n, w, int(h / 2), int(c * 2))
93
+ x = x.permute(0, 2, 1, 3).contiguous()
94
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
95
+ x = x.permute(0, 2, 1, 3).contiguous()
96
+ return x
97
+
98
+
99
+ class DownSample3x3BlockFix(nn.Module):
100
+ def forward(self, x):
101
+ vit_embeds = x
102
+ h = w = int(vit_embeds.shape[1] ** 0.5)
103
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
104
+ vit_embeds = flat_square_3x3(vit_embeds)
105
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
106
+ return vit_embeds
107
+
108
+
109
+ def flat_square_3x3(x):
110
+ n, w, h, c = x.size()
111
+ if w % 3 != 0:
112
+ x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
113
+ n, w, h, c = x.size()
114
+ x = x.contiguous()
115
+ if h % 3 != 0:
116
+ x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
117
+ n, w, h, c = x.size()
118
+ x = x.view(n, w, int(h / 3), int(c * 3))
119
+ x = x.permute(0, 2, 1, 3).contiguous()
120
+ x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
121
+ x = x.permute(0, 2, 1, 3).contiguous()
122
+ return x
123
+
124
+
125
+ class MultimodalProjectorConfig(PretrainedConfig):
126
+ model_type = "v2l_projector"
127
+
128
+ def __init__(self, mm_projector_type: str = None, **kwargs):
129
+ super().__init__()
130
+ self.mm_projector_type = mm_projector_type
131
+
132
+
133
+ class MultimodalProjector(PreTrainedModel):
134
+ config_class = MultimodalProjectorConfig
135
+
136
+ def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
137
+ super().__init__(mm_projector_cfg)
138
+ mm_projector_type = mm_projector_cfg.mm_projector_type
139
+ self.downsample_rate = 1
140
+ if mm_projector_type == "identity":
141
+ self.layers = IdentityMap()
142
+ elif mm_projector_type == "linear":
143
+ self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
144
+ elif mm_projector_type == "mlp_downsample":
145
+ self.layers = nn.Sequential(
146
+ DownSampleBlock(),
147
+ nn.LayerNorm(config.mm_hidden_size * 4),
148
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
149
+ nn.GELU(),
150
+ nn.Linear(config.hidden_size, config.hidden_size),
151
+ )
152
+ self.downsample_rate = 2
153
+ elif mm_projector_type == "mlp_downsample_2x2_fix":
154
+ self.layers = nn.Sequential(
155
+ DownSample2x2BlockFix(),
156
+ nn.LayerNorm(config.mm_hidden_size * 4),
157
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
158
+ nn.GELU(),
159
+ nn.Linear(config.hidden_size, config.hidden_size),
160
+ )
161
+ self.downsample_rate = 2
162
+ elif mm_projector_type == "mlp_downsample_3x3_fix":
163
+ self.layers = nn.Sequential(
164
+ DownSample3x3BlockFix(),
165
+ nn.LayerNorm(config.mm_hidden_size * 9),
166
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
167
+ nn.GELU(),
168
+ nn.LayerNorm(config.mm_hidden_size * 3),
169
+ nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(config.hidden_size, config.hidden_size),
172
+ )
173
+ self.downsample_rate = 3
174
+ elif mm_projector_type == "mlp_downsample_3x3_s2":
175
+ self.layers = nn.Sequential(
176
+ DownSample3x3BlockFix(),
177
+ nn.LayerNorm(config.mm_hidden_size * 9),
178
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
179
+ nn.GELU(),
180
+ nn.LayerNorm(config.mm_hidden_size * 3),
181
+ nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
182
+ nn.GELU(),
183
+ nn.LayerNorm(config.mm_hidden_size),
184
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
185
+ nn.GELU(),
186
+ nn.LayerNorm(config.mm_hidden_size // 3),
187
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
188
+ nn.GELU(),
189
+ nn.Linear(config.hidden_size, config.hidden_size),
190
+ )
191
+ elif mm_projector_type == "mlp_downsample_3x3_s2_new":
192
+ self.layers = nn.Sequential(
193
+ DownSample3x3BlockFix(),
194
+ nn.LayerNorm(config.mm_hidden_size * 9),
195
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
196
+ nn.GELU(),
197
+ nn.LayerNorm(config.mm_hidden_size * 4),
198
+ nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
199
+ nn.GELU(),
200
+ nn.LayerNorm(config.mm_hidden_size * 2),
201
+ nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
202
+ nn.GELU(),
203
+ nn.LayerNorm(config.mm_hidden_size),
204
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
205
+ nn.GELU(),
206
+ nn.LayerNorm(config.mm_hidden_size // 3),
207
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
208
+ nn.GELU(),
209
+ nn.Linear(config.hidden_size, config.hidden_size),
210
+ )
211
+ else:
212
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
213
+ if mlp_gelu_match:
214
+ mlp_depth = int(mlp_gelu_match.group(1))
215
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
216
+ for _ in range(1, mlp_depth):
217
+ modules.append(nn.GELU())
218
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
219
+ self.layers = nn.Sequential(*modules)
220
+ else:
221
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
222
+
223
+ def forward(self, x, *args, **kwargs):
224
+ return self.layers(x)
225
+
226
+
227
+ # AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
228
+ # AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
builder.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import math
18
+ import os
19
+ import os.path as osp
20
+ import warnings
21
+ from dataclasses import asdict
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
23
+
24
+ import torch
25
+ from huggingface_hub import file_exists, repo_exists
26
+ from huggingface_hub.utils import HFValidationError
27
+ import transformers
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModelForCausalLM,
31
+ AutoTokenizer,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizer,
35
+ )
36
+ # from .conversation import *
37
+ from .conversation import default_conversation, SeparatorStyle
38
+
39
+ SENTINEL_TOKEN = "<vila/sentinel>"
40
+ MEDIA_TOKENS = {
41
+ "image": "<image>",
42
+ "video": "<vila/video>",
43
+ }
44
+
45
+ # from llava.model.utils import packing
46
+ # from llava.utils.logging import logger
47
+ # from llava.utils.tokenizer import infer_stop_tokens
48
+
49
+ DUMMY_CONVERSATION = [
50
+ {"from": "human", "value": "question"},
51
+ {"from": "gpt", "value": "answer"},
52
+ ] * 10
53
+
54
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
55
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
56
+
57
+ def has_tokenizer(repo_id_or_path: str) -> bool:
58
+ # Check if the tokenizer is in a local directory
59
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
60
+ return True
61
+
62
+ # Check if the tokenizer is in a Hugging Face Hub repo
63
+ try:
64
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
65
+ except HFValidationError:
66
+ return False
67
+
68
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
69
+ if not hasattr(tokenizer, "sentinel_token"):
70
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
71
+ tokenizer.sentinel_token = SENTINEL_TOKEN
72
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
73
+
74
+ def tokenize_conversation_legacy(
75
+ messages: Sequence[Dict[str, str]],
76
+ tokenizer: transformers.PreTrainedTokenizer,
77
+ add_generation_prompt: bool = False,
78
+ overrides: Optional[Dict[str, str]] = None,
79
+ no_system_prompt: bool = False,
80
+ ) -> torch.Tensor:
81
+ conv = default_conversation.copy()
82
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
83
+
84
+ if no_system_prompt:
85
+ conv.system = ""
86
+
87
+ # Skip the first message if it is not from human
88
+ if messages[0]["from"] != "human":
89
+ messages = messages[1:]
90
+
91
+ # Add a generation prompt if needed
92
+ if add_generation_prompt:
93
+ messages.append({"from": "gpt", "value": None})
94
+
95
+ conv.messages = []
96
+ for turn, message in enumerate(messages):
97
+ role = roles[message["from"]]
98
+ assert role == conv.roles[turn % 2]
99
+ if overrides is not None and message["from"] in overrides:
100
+ conv.append_message(role, overrides[message["from"]])
101
+ else:
102
+ conv.append_message(role, message["value"])
103
+
104
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
105
+
106
+ def tokenize_conversation(
107
+ messages: Sequence[Dict[str, str]],
108
+ tokenizer: transformers.PreTrainedTokenizer,
109
+ add_generation_prompt: bool = False,
110
+ overrides: Optional[Dict[str, str]] = None,
111
+ no_system_prompt: bool = False,
112
+ ) -> torch.Tensor:
113
+ # Normalize the conversation before tokenization
114
+ for message in messages:
115
+ message["value"] = message["value"].strip()
116
+
117
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
118
+ return tokenize_conversation_legacy(
119
+ messages,
120
+ tokenizer,
121
+ add_generation_prompt=add_generation_prompt,
122
+ overrides=overrides,
123
+ no_system_prompt=no_system_prompt,
124
+ )
125
+
126
+ conversation = []
127
+ for m in messages:
128
+ message = {}
129
+ if m["from"] == "human":
130
+ message["role"] = "user"
131
+ elif m["from"] == "gpt":
132
+ message["role"] = "assistant"
133
+ else:
134
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
135
+
136
+ message["content"] = m["value"]
137
+ if overrides is not None and m["from"] in overrides:
138
+ message["content"] = overrides[m["from"]]
139
+ conversation.append(message)
140
+
141
+ if no_system_prompt:
142
+ conversation = [{"role": "system", "content": ""}] + conversation
143
+
144
+ text = tokenizer.apply_chat_template(
145
+ conversation,
146
+ add_generation_prompt=add_generation_prompt,
147
+ tokenize=False,
148
+ )
149
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
150
+
151
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
152
+ _maybe_add_sentinel_token(tokenizer)
153
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
154
+
155
+ stop_tokens = {tokenizer.eos_token}
156
+ for k in range(template.size(0) - 1):
157
+ if template[k] == tokenizer.sentinel_token_id:
158
+ stop_token = tokenizer.decode(template[k + 1])
159
+ stop_tokens.add(stop_token)
160
+ return list(stop_tokens)
161
+
162
+ def context_length_extension(config):
163
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
164
+ model_max_length = getattr(config, "model_max_length", None)
165
+ if orig_ctx_len and model_max_length > orig_ctx_len:
166
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
167
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
168
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
169
+ return config
170
+
171
+
172
+ def build_llm_and_tokenizer(
173
+ model_name_or_path: str,
174
+ config: PretrainedConfig,
175
+ attn_implementation=None,
176
+ model_max_length=None,
177
+ *args,
178
+ **kwargs,
179
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
180
+ # print(model_name_or_path)
181
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
182
+ llm_cfg._attn_implementation = attn_implementation
183
+ llm_cfg.model_max_length = model_max_length
184
+ if model_max_length is not None:
185
+ context_length_extension(llm_cfg)
186
+
187
+ # Quantization related
188
+ quantization_restore_from_checkpoint = False
189
+
190
+ if quantization_restore_from_checkpoint:
191
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
192
+
193
+ llm = AutoModelForCausalLM.from_pretrained(
194
+ fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
195
+ )
196
+ else:
197
+ llm = AutoModelForCausalLM.from_pretrained(
198
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
199
+ )
200
+ # NOTE(ligeng): not sure whether it affects the training
201
+ # packing.patch(llm)
202
+
203
+ # Locate the tokenizer.
204
+ llm_path = model_name_or_path
205
+ if not has_tokenizer(llm_path):
206
+ llm_path = osp.join(llm_path, "llm")
207
+ if not has_tokenizer(llm_path):
208
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
209
+
210
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
211
+ if model_max_length is not None:
212
+ tokenizer.model_max_length = model_max_length
213
+
214
+ # Load chat template if specified.
215
+ if getattr(config, "chat_template", None) is not None:
216
+ print(f"Using chat template: {config.chat_template}")
217
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
218
+ with open(fpath) as fd:
219
+ chat_template = fd.read()
220
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
221
+
222
+ # NOTE(ligeng): disable temporarially, let see will any bugs introduce
223
+ # Set stop tokens for the tokenizer
224
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
225
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
226
+
227
+ # Add media tokens to the tokenizer
228
+ tokenizer.media_tokens = MEDIA_TOKENS
229
+ tokenizer.media_token_ids = {}
230
+ for name, token in MEDIA_TOKENS.items():
231
+ tokenizer.add_tokens([token], special_tokens=True)
232
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
233
+
234
+ # TODO(ligeng): is this necessary for llava?
235
+ config.hidden_size = llm.config.hidden_size
236
+ return llm, tokenizer
config.json ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model",
4
+ "architectures": [
5
+ "VILAForCasualLM"
6
+ ],
7
+ "chat_template": null,
8
+ "drop_path_rate": 0.0,
9
+ "dynamic_s2": false,
10
+ "fps": 0.0,
11
+ "hidden_size": 3584,
12
+ "image_aspect_ratio": "dynamic",
13
+ "interpolate_mode": "linear",
14
+ "llm_cfg": {
15
+ "_attn_implementation_autoset": false,
16
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/llm",
17
+ "add_cross_attention": false,
18
+ "architectures": [
19
+ "Qwen2ForCausalLM"
20
+ ],
21
+ "attention_dropout": 0.0,
22
+ "bad_words_ids": null,
23
+ "begin_suppress_tokens": null,
24
+ "bos_token_id": 151643,
25
+ "chunk_size_feed_forward": 0,
26
+ "cross_attention_hidden_size": null,
27
+ "decoder_start_token_id": null,
28
+ "diversity_penalty": 0.0,
29
+ "do_sample": false,
30
+ "early_stopping": false,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "eos_token_id": 151645,
33
+ "exponential_decay_length_penalty": null,
34
+ "finetuning_task": null,
35
+ "forced_bos_token_id": null,
36
+ "forced_eos_token_id": null,
37
+ "hidden_act": "silu",
38
+ "hidden_size": 3584,
39
+ "id2label": {
40
+ "0": "LABEL_0",
41
+ "1": "LABEL_1"
42
+ },
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 18944,
45
+ "is_decoder": false,
46
+ "is_encoder_decoder": false,
47
+ "label2id": {
48
+ "LABEL_0": 0,
49
+ "LABEL_1": 1
50
+ },
51
+ "length_penalty": 1.0,
52
+ "max_length": 20,
53
+ "max_position_embeddings": 32768,
54
+ "max_window_layers": 28,
55
+ "min_length": 0,
56
+ "model_max_length": 4096,
57
+ "model_type": "qwen2",
58
+ "no_repeat_ngram_size": 0,
59
+ "num_attention_heads": 28,
60
+ "num_beam_groups": 1,
61
+ "num_beams": 1,
62
+ "num_hidden_layers": 28,
63
+ "num_key_value_heads": 4,
64
+ "num_return_sequences": 1,
65
+ "output_attentions": false,
66
+ "output_hidden_states": false,
67
+ "output_scores": false,
68
+ "pad_token_id": null,
69
+ "prefix": null,
70
+ "problem_type": null,
71
+ "pruned_heads": {},
72
+ "remove_invalid_values": false,
73
+ "repetition_penalty": 1.0,
74
+ "return_dict": true,
75
+ "return_dict_in_generate": false,
76
+ "rms_norm_eps": 1e-06,
77
+ "rope_scaling": null,
78
+ "rope_theta": 1000000.0,
79
+ "sep_token_id": null,
80
+ "sliding_window": null,
81
+ "suppress_tokens": null,
82
+ "task_specific_params": null,
83
+ "temperature": 1.0,
84
+ "tf_legacy_loss": false,
85
+ "tie_encoder_decoder": false,
86
+ "tie_word_embeddings": false,
87
+ "tokenizer_class": null,
88
+ "tokenizer_model_max_length": 4096,
89
+ "tokenizer_padding_side": "right",
90
+ "top_k": 50,
91
+ "top_p": 1.0,
92
+ "torch_dtype": "bfloat16",
93
+ "torchscript": false,
94
+ "typical_p": 1.0,
95
+ "use_bfloat16": false,
96
+ "use_cache": true,
97
+ "use_sliding_window": false,
98
+ "vocab_size": 151648
99
+ },
100
+ "mm_hidden_size": 1152,
101
+ "mm_projector_cfg": {
102
+ "_attn_implementation_autoset": false,
103
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/mm_projector",
104
+ "add_cross_attention": false,
105
+ "architectures": [
106
+ "MultimodalProjector"
107
+ ],
108
+ "bad_words_ids": null,
109
+ "begin_suppress_tokens": null,
110
+ "bos_token_id": null,
111
+ "chunk_size_feed_forward": 0,
112
+ "cross_attention_hidden_size": null,
113
+ "decoder_start_token_id": null,
114
+ "diversity_penalty": 0.0,
115
+ "do_sample": false,
116
+ "early_stopping": false,
117
+ "encoder_no_repeat_ngram_size": 0,
118
+ "eos_token_id": null,
119
+ "exponential_decay_length_penalty": null,
120
+ "finetuning_task": null,
121
+ "forced_bos_token_id": null,
122
+ "forced_eos_token_id": null,
123
+ "id2label": {
124
+ "0": "LABEL_0",
125
+ "1": "LABEL_1"
126
+ },
127
+ "is_decoder": false,
128
+ "is_encoder_decoder": false,
129
+ "label2id": {
130
+ "LABEL_0": 0,
131
+ "LABEL_1": 1
132
+ },
133
+ "length_penalty": 1.0,
134
+ "max_length": 20,
135
+ "min_length": 0,
136
+ "mm_projector_type": "mlp_downsample_3x3_fix",
137
+ "model_type": "v2l_projector",
138
+ "no_repeat_ngram_size": 0,
139
+ "num_beam_groups": 1,
140
+ "num_beams": 1,
141
+ "num_return_sequences": 1,
142
+ "output_attentions": false,
143
+ "output_hidden_states": false,
144
+ "output_scores": false,
145
+ "pad_token_id": null,
146
+ "prefix": null,
147
+ "problem_type": null,
148
+ "pruned_heads": {},
149
+ "remove_invalid_values": false,
150
+ "repetition_penalty": 1.0,
151
+ "return_dict": true,
152
+ "return_dict_in_generate": false,
153
+ "sep_token_id": null,
154
+ "suppress_tokens": null,
155
+ "task_specific_params": null,
156
+ "temperature": 1.0,
157
+ "tf_legacy_loss": false,
158
+ "tie_encoder_decoder": false,
159
+ "tie_word_embeddings": true,
160
+ "tokenizer_class": null,
161
+ "top_k": 50,
162
+ "top_p": 1.0,
163
+ "torch_dtype": "bfloat16",
164
+ "torchscript": false,
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ },
168
+ "mm_projector_lr": null,
169
+ "mm_use_im_patch_token": true,
170
+ "mm_use_im_start_end": false,
171
+ "mm_vision_select_feature": "cls_patch",
172
+ "mm_vision_select_layer": -2,
173
+ "model_dtype": "torch.bfloat16",
174
+ "model_type": "vila",
175
+ "num_time_tokens": 0,
176
+ "num_video_frames": 8,
177
+ "resume_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model",
178
+ "s2": false,
179
+ "s2_max_split_size": 336,
180
+ "s2_resize_output_to_scale_idx": 0,
181
+ "s2_scales": "336,672,1008",
182
+ "soft_ce_std": 1.0,
183
+ "time_token_format": "<t{t}>",
184
+ "time_token_ids": [],
185
+ "transformers_version": "4.46.0",
186
+ "tune_language_model": true,
187
+ "tune_mm_projector": true,
188
+ "tune_vision_tower": true,
189
+ "vision_resolution": -1,
190
+ "vision_tower_cfg": {
191
+ "_attn_implementation_autoset": false,
192
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/vision_tower",
193
+ "add_cross_attention": false,
194
+ "architectures": [
195
+ "SiglipVisionModel"
196
+ ],
197
+ "attention_dropout": 0.0,
198
+ "bad_words_ids": null,
199
+ "begin_suppress_tokens": null,
200
+ "bos_token_id": null,
201
+ "chunk_size_feed_forward": 0,
202
+ "cross_attention_hidden_size": null,
203
+ "decoder_start_token_id": null,
204
+ "diversity_penalty": 0.0,
205
+ "do_sample": false,
206
+ "early_stopping": false,
207
+ "encoder_no_repeat_ngram_size": 0,
208
+ "eos_token_id": null,
209
+ "exponential_decay_length_penalty": null,
210
+ "finetuning_task": null,
211
+ "forced_bos_token_id": null,
212
+ "forced_eos_token_id": null,
213
+ "hidden_act": "gelu_pytorch_tanh",
214
+ "hidden_size": 1152,
215
+ "id2label": {
216
+ "0": "LABEL_0",
217
+ "1": "LABEL_1"
218
+ },
219
+ "image_size": 448,
220
+ "intermediate_size": 4304,
221
+ "is_decoder": false,
222
+ "is_encoder_decoder": false,
223
+ "label2id": {
224
+ "LABEL_0": 0,
225
+ "LABEL_1": 1
226
+ },
227
+ "layer_norm_eps": 1e-06,
228
+ "length_penalty": 1.0,
229
+ "max_length": 20,
230
+ "min_length": 0,
231
+ "model_type": "siglip_vision_model",
232
+ "no_repeat_ngram_size": 0,
233
+ "num_attention_heads": 16,
234
+ "num_beam_groups": 1,
235
+ "num_beams": 1,
236
+ "num_channels": 3,
237
+ "num_hidden_layers": 27,
238
+ "num_image_tokens": 256,
239
+ "num_return_sequences": 1,
240
+ "output_attentions": false,
241
+ "output_hidden_states": false,
242
+ "output_scores": false,
243
+ "pad_token_id": null,
244
+ "patch_size": 14,
245
+ "prefix": null,
246
+ "problem_type": null,
247
+ "projection_dim": 2048,
248
+ "projector_hidden_act": "gelu_fast",
249
+ "pruned_heads": {},
250
+ "remove_invalid_values": false,
251
+ "repetition_penalty": 1.0,
252
+ "return_dict": true,
253
+ "return_dict_in_generate": false,
254
+ "sep_token_id": null,
255
+ "suppress_tokens": null,
256
+ "task_specific_params": null,
257
+ "temperature": 1.0,
258
+ "tf_legacy_loss": false,
259
+ "tie_encoder_decoder": false,
260
+ "tie_word_embeddings": true,
261
+ "tokenizer_class": null,
262
+ "top_k": 50,
263
+ "top_p": 1.0,
264
+ "torch_dtype": "bfloat16",
265
+ "torchscript": false,
266
+ "typical_p": 1.0,
267
+ "use_bfloat16": false,
268
+ "vision_use_head": false
269
+ },
270
+ "version": "2.0",
271
+ "auto_map": {
272
+ "AutoConfig": "modeling_vila.VILAConfig",
273
+ "AutoModel": "modeling_vila.VILAForCasualLM",
274
+ "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
275
+ }
276
+ }
configuration_vila.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional
3
+ import json
4
+ import torch
5
+ import torchvision
6
+ import os, os.path as osp
7
+
8
+ from threading import Thread
9
+ from copy import deepcopy
10
+ from PIL import Image
11
+ from transformers import Qwen2Config, PretrainedConfig, PreTrainedModel
12
+ from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
13
+
14
+ class VILAConfig(PretrainedConfig):
15
+ model_type = "vila"
16
+ keys_to_ignore_at_inference = ["past_key_values"]
17
+
18
+ def __init__(
19
+ self,
20
+ llm_cfg=None,
21
+ vision_tower_cfg=None,
22
+ mm_projector_cfg=None,
23
+ architectures=None,
24
+ resume_path=None,
25
+ hidden_size=None,
26
+ mm_hidden_size=None,
27
+ image_aspect_ratio=None,
28
+ num_video_frames=None,
29
+ fps=None,
30
+ mm_vision_select_layer=None,
31
+ mm_vision_select_feature=None,
32
+ mm_use_im_start_end=False,
33
+ mm_use_im_patch_token=False,
34
+ mm_projector_lr=None,
35
+ vision_tower_lr=None,
36
+ vision_resolution=None,
37
+ interpolate_mode=None,
38
+ s2=None,
39
+ dynamic_s2=None,
40
+ s2_scales=None,
41
+ s2_max_split_size=None,
42
+ s2_resize_output_to_scale_idx=0,
43
+ min_tiles: Optional[int] = 1,
44
+ max_tiles: Optional[int] = 12,
45
+ num_time_tokens=None,
46
+ time_token_format=None,
47
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
48
+ video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
49
+ **kwargs,
50
+ ):
51
+ super().__init__()
52
+ self.architectures = architectures
53
+ self.llm_cfg = llm_cfg
54
+ self.vision_tower_cfg = vision_tower_cfg
55
+ self.mm_projector_cfg = mm_projector_cfg
56
+ self.resume_path = resume_path
57
+
58
+ self.hidden_size = hidden_size
59
+ self.mm_hidden_size = mm_hidden_size
60
+ self.image_aspect_ratio = image_aspect_ratio
61
+ self.num_video_frames = num_video_frames
62
+ self.fps = fps
63
+ self.mm_vision_select_layer = mm_vision_select_layer
64
+ self.mm_vision_select_feature = mm_vision_select_feature
65
+ self.mm_use_im_start_end = mm_use_im_start_end
66
+ self.mm_use_im_patch_token = mm_use_im_patch_token
67
+ self.mm_projector_lr = mm_projector_lr
68
+ self.vision_tower_lr = vision_tower_lr
69
+ self.vision_resolution = vision_resolution
70
+ self.interpolate_mode = interpolate_mode
71
+ self.s2 = s2
72
+ self.dynamic_s2 = dynamic_s2
73
+ self.s2_scales = s2_scales
74
+ self.s2_max_split_size = s2_max_split_size
75
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
76
+ self.min_tiles = min_tiles
77
+ self.max_tiles = max_tiles
78
+ self.num_time_tokens = num_time_tokens
79
+ self.time_token_format = time_token_format
80
+
81
+ self.image_encoder = image_encoder
82
+ self.video_encoder = video_encoder
83
+
84
+ super().__init__(**kwargs)
85
+
constants.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
18
+ WORKER_HEART_BEAT_INTERVAL = 15
19
+
20
+ LOGDIR = "."
21
+
22
+ # Model Constants
23
+ IGNORE_INDEX = -100
24
+ DEFAULT_IMAGE_TOKEN = "<image>"
25
+
26
+ SENTINEL_TOKEN = "<vila/sentinel>"
27
+ MEDIA_TOKENS = {
28
+ "image": "<image>",
29
+ "video": "<vila/video>",
30
+ }
31
+ # <image> <vila/video> <vila/sentinel>
32
+ # TODO(ligeng): need to discuss with Zhijian for the following tokens for different models.
33
+ """
34
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
35
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
36
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
37
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
38
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
39
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
40
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
41
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
42
+ """
43
+ NUM_EXTRA_TOKENS = 8
conversation.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+ import dataclasses
19
+ from enum import Enum, auto
20
+ from typing import List
21
+
22
+ # from llava.utils.logging import logger
23
+
24
+
25
+ class SeparatorStyle(Enum):
26
+ """Different separator style."""
27
+
28
+ AUTO = auto()
29
+ TWO = auto()
30
+ MPT = auto()
31
+ PLAIN = auto()
32
+ LLAMA_3 = auto()
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class Conversation:
37
+ """A class that keeps all conversation history."""
38
+
39
+ system: str
40
+ roles: List[str]
41
+ messages: List[List[str]]
42
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
43
+ sep: str = "###"
44
+ sep2: str = None
45
+ version: str = "Unknown"
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0].replace("<image>", "").strip()
53
+ messages[0] = (init_role, "<image>\n" + init_msg)
54
+
55
+ if self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message, _, _ = message
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
66
+ ret = self.system + self.sep
67
+ for rid, (role, message) in enumerate(messages):
68
+ if message:
69
+ if type(message) is tuple:
70
+ message = message[0]
71
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
72
+ ret += role + message + sep
73
+ else:
74
+ ret += role
75
+ elif self.sep_style == SeparatorStyle.MPT:
76
+ ret = self.system + self.sep
77
+ for role, message in messages:
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + message + self.sep
82
+ else:
83
+ ret += role
84
+ elif self.sep_style == SeparatorStyle.PLAIN:
85
+ seps = [self.sep, self.sep2]
86
+ ret = self.system
87
+ for i, (role, message) in enumerate(messages):
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, _, _ = message
91
+ ret += message + seps[i % 2]
92
+ else:
93
+ ret += ""
94
+ else:
95
+ raise ValueError(f"Invalid style: {self.sep_style}")
96
+
97
+ return ret
98
+
99
+ def append_message(self, role, message):
100
+ self.messages.append([role, message])
101
+
102
+ def copy(self):
103
+ return Conversation(
104
+ system=self.system,
105
+ roles=self.roles,
106
+ messages=[[x, y] for x, y in self.messages],
107
+ sep_style=self.sep_style,
108
+ sep=self.sep,
109
+ sep2=self.sep2,
110
+ version=self.version,
111
+ )
112
+
113
+
114
+ conv_auto = Conversation(
115
+ system="",
116
+ roles=("", ""),
117
+ messages=(),
118
+ sep_style=SeparatorStyle.AUTO,
119
+ sep="\n",
120
+ )
121
+
122
+ conv_vicuna_v1 = Conversation(
123
+ system="A chat between a curious user and an artificial intelligence assistant. "
124
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
125
+ roles=("USER", "ASSISTANT"),
126
+ version="v1",
127
+ messages=(),
128
+ sep_style=SeparatorStyle.TWO,
129
+ sep=" ",
130
+ sep2="</s>",
131
+ )
132
+
133
+ conv_llava_plain = Conversation(
134
+ system="",
135
+ roles=("", ""),
136
+ messages=(),
137
+ sep_style=SeparatorStyle.PLAIN,
138
+ sep="\n",
139
+ )
140
+
141
+ hermes_2 = Conversation(
142
+ system="<|im_start|>system\nAnswer the questions.",
143
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
144
+ sep_style=SeparatorStyle.MPT,
145
+ sep="<|im_end|>",
146
+ messages=(),
147
+ version="hermes-2",
148
+ )
149
+
150
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
151
+ llama_3_chat = Conversation(
152
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
153
+ "You are able to understand the visual content that the user provides, "
154
+ "and assist the user with a variety of tasks using natural language.",
155
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
156
+ version="llama_v3",
157
+ messages=(),
158
+ sep_style=SeparatorStyle.LLAMA_3,
159
+ sep="<|eot_id|>",
160
+ sep2="<|end_of_text|>",
161
+ )
162
+
163
+
164
+ default_conversation = conv_auto
165
+ conv_templates = {
166
+ "auto": conv_auto,
167
+ "hermes-2": hermes_2,
168
+ "llama_3": llama_3_chat,
169
+ "v1": conv_vicuna_v1,
170
+ "vicuna_v1": conv_vicuna_v1,
171
+ "plain": conv_llava_plain,
172
+ }
173
+
174
+
175
+ CONVERSATION_MODE_MAPPING = {
176
+ "vila1.5-3b": "vicuna_v1",
177
+ "vila1.5-8b": "llama_3",
178
+ "vila1.5-13b": "vicuna_v1",
179
+ "vila1.5-40b": "hermes-2",
180
+ "llama-3": "llama_3",
181
+ "llama3": "llama_3",
182
+ }
183
+
184
+
185
+ def auto_set_conversation_mode(model_name_or_path: str) -> str:
186
+ global default_conversation
187
+ for k, v in CONVERSATION_MODE_MAPPING.items():
188
+ if k in model_name_or_path.lower():
189
+ print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
190
+ default_conversation = conv_templates[v]
191
+ return
llm/added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|endoftext|>": 151643,
3
+ "<|im_end|>": 151645,
4
+ "<|im_start|>": 151644,
5
+ "[BOS]": 151646,
6
+ "[PAD]": 151647
7
+ }
llm/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/llm",
3
+ "architectures": [
4
+ "Qwen2ForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151645,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 3584,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 18944,
13
+ "max_position_embeddings": 32768,
14
+ "max_window_layers": 28,
15
+ "model_max_length": 4096,
16
+ "model_type": "qwen2",
17
+ "num_attention_heads": 28,
18
+ "num_hidden_layers": 28,
19
+ "num_key_value_heads": 4,
20
+ "rms_norm_eps": 1e-06,
21
+ "rope_scaling": null,
22
+ "rope_theta": 1000000.0,
23
+ "sliding_window": null,
24
+ "tie_word_embeddings": false,
25
+ "tokenizer_model_max_length": 4096,
26
+ "tokenizer_padding_side": "right",
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.46.0",
29
+ "use_cache": true,
30
+ "use_sliding_window": false,
31
+ "vocab_size": 151648
32
+ }
llm/generation_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "repetition_penalty": 1.05,
10
+ "temperature": 0.7,
11
+ "top_k": 20,
12
+ "top_p": 0.8,
13
+ "transformers_version": "4.46.0"
14
+ }
llm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
llm/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02c537cc95bae7ddd1cf1278bf8cb78e5f09e81dce7bcaf4d2451ea9b49a5be4
3
+ size 4874678888
llm/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61120bdc643babbb5563301a8d4f67457335d4a538e6ad66507c778fbe262eff
3
+ size 4932751008
llm/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95ddf0ba4eda4bafce13a0bb6f8436ac80efd061153a9fe2c3d5dec3c54d5ca7
3
+ size 4330865200
llm/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08b87c792d141f74118ac04eb9b7c96f2961189dd844e124ed03fa726afff263
3
+ size 1087012992
llm/model.safetensors.index.json ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15225269248
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
62
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
74
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
86
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
89
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
91
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
101
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
103
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
110
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
113
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
115
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
122
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
125
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
126
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
127
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
129
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
133
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
134
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
137
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
139
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
140
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
146
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
149
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
151
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
153
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
154
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
155
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
156
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
157
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
158
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
159
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
160
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
161
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
162
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
163
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
164
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
170
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
172
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
173
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
174
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
175
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
182
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
185
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
187
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
194
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
197
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
199
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
206
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
209
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
211
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
218
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
221
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
223
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
230
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
233
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
235
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
242
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
244
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
245
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
247
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
257
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
261
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
262
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
263
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
264
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
265
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
266
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
267
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
268
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
269
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
270
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
271
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
272
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
273
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
274
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
275
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
276
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
277
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
278
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
279
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
280
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
281
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
282
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
283
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
284
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
285
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
286
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
287
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
288
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
289
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
290
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
292
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
293
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
294
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
295
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
296
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
297
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
298
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
299
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
300
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
301
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
302
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
303
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
304
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
305
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
306
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
307
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
308
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
309
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
310
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
311
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
312
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
313
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
314
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
315
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
316
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
317
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
318
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
319
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
320
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
321
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
322
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
323
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
324
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
325
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
326
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
327
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
328
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
329
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
330
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
331
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
332
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
333
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
334
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
335
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
336
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
337
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
338
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
339
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
340
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
341
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
342
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
343
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
344
+ "model.norm.weight": "model-00003-of-00004.safetensors"
345
+ }
346
+ }
llm/special_tokens_map.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>"
5
+ ],
6
+ "bos_token": {
7
+ "content": "[BOS]",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false
12
+ },
13
+ "eos_token": {
14
+ "content": "<|im_end|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "pad_token": {
21
+ "content": "[PAD]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ }
27
+ }
llm/tokenizer_config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "[BOS]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "151647": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "additional_special_tokens": [
46
+ "<|im_start|>",
47
+ "<|im_end|>"
48
+ ],
49
+ "bos_token": "[BOS]",
50
+ "chat_template": "{% if messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages if message['content'] is not none %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
51
+ "clean_up_tokenization_spaces": false,
52
+ "eos_token": "<|im_end|>",
53
+ "errors": "replace",
54
+ "legacy": false,
55
+ "model_max_length": 4096,
56
+ "pad_token": "[PAD]",
57
+ "padding_side": "right",
58
+ "split_special_tokens": false,
59
+ "tokenizer_class": "Qwen2Tokenizer",
60
+ "unk_token": null
61
+ }
llm/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
File without changes
media.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from collections import defaultdict
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import PIL
9
+ import PIL.Image
10
+ import requests
11
+ from transformers import PretrainedConfig
12
+
13
+ # from llava.constants import MEDIA_TOKENS
14
+ # from llava.media import Image, Video
15
+ # from llava.utils import make_list
16
+ # from llava.utils.logging import logger
17
+
18
+ MEDIA_TOKENS = {
19
+ "image": "<image>",
20
+ "video": "<vila/video>",
21
+ }
22
+
23
+ class Media:
24
+ pass
25
+
26
+ class File(Media):
27
+ def __init__(self, path: str) -> None:
28
+ self.path = path
29
+
30
+ class Image(File):
31
+ pass
32
+
33
+
34
+ class Video(File):
35
+ pass
36
+
37
+ def make_list(obj: Any) -> List:
38
+ return obj if isinstance(obj, list) else [obj]
39
+
40
+
41
+ def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
42
+ if isinstance(image, Image):
43
+ if image.path.startswith("http://") or image.path.startswith("https://"):
44
+ image = PIL.Image.open(requests.get(image.path, stream=True).raw)
45
+ else:
46
+ image = PIL.Image.open(image.path)
47
+ return image
48
+
49
+
50
+ def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
51
+ # Load video frames from a directory
52
+ if os.path.isdir(video_path):
53
+ frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
54
+ indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
55
+ return [PIL.Image.open(frame_paths[index]) for index in indices]
56
+
57
+ # Load video frames from a video file
58
+ vidcap = cv2.VideoCapture(video_path)
59
+
60
+ # Find the last frame as frame count might not be accurate
61
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
62
+ while frame_count > 0:
63
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
64
+ if vidcap.grab():
65
+ break
66
+ frame_count -= 1
67
+ else:
68
+ raise ValueError(f"Video '{video_path}' has no frames.")
69
+
70
+ # Extract frames uniformly
71
+ indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
72
+ frames = {}
73
+ for index in indices:
74
+ if index in frames:
75
+ continue
76
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
77
+ success, frame = vidcap.read()
78
+ if not success:
79
+ print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
80
+ continue
81
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
82
+ frames[index] = PIL.Image.fromarray(frame)
83
+ return [frames[index] for index in indices if index in frames]
84
+
85
+
86
+ def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
87
+ num_frames = config.num_video_frames
88
+ if getattr(config, "fps") != 0:
89
+ print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
90
+
91
+ frames = _load_video(video.path, num_frames=num_frames)
92
+ return frames
93
+
94
+
95
+ def extract_media(
96
+ messages: List[Dict[str, Any]],
97
+ config: Optional[PretrainedConfig] = None,
98
+ draft: bool = False,
99
+ ) -> Dict[str, List[Any]]:
100
+ media = defaultdict(list)
101
+ for message in messages:
102
+ text = ""
103
+ for part in make_list(message["value"]):
104
+ if isinstance(part, str):
105
+ for token in MEDIA_TOKENS.values():
106
+ if token in part:
107
+ print(f"Media token '{token}' found in text: '{part}'. Removed.")
108
+ part = part.replace(token, "").strip()
109
+ text += part
110
+ elif isinstance(part, (Image, PIL.Image.Image)):
111
+ if draft:
112
+ media["image"].append(part)
113
+ else:
114
+ media["image"].append(_extract_image(part))
115
+ text += MEDIA_TOKENS["image"]
116
+ elif isinstance(part, Video):
117
+ if draft:
118
+ media["video"].append(part)
119
+ else:
120
+ media["video"].append(_extract_video(part, config))
121
+ text += MEDIA_TOKENS["video"]
122
+ else:
123
+ raise ValueError(f"Unsupported prompt part type: {type(part)}")
124
+ message["value"] = text
125
+ return media
media_encoder.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Optional
5
+
6
+
7
+ class BaseEncoder(nn.Module):
8
+ def __init__(self, parent: nn.Module) -> None:
9
+ super().__init__()
10
+ self._parent = [parent]
11
+
12
+ @property
13
+ def parent(self) -> nn.Module:
14
+ return self._parent[0]
15
+
16
+
17
+ class BasicImageEncoder(BaseEncoder):
18
+ def __init__(
19
+ self,
20
+ parent: torch.nn.Module,
21
+ start_tokens: Optional[str] = None,
22
+ end_tokens: Optional[str] = "\n",
23
+ ) -> None:
24
+ super().__init__(parent)
25
+ self.start_tokens = start_tokens
26
+ self.end_tokens = end_tokens
27
+
28
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
29
+ if tokens is None:
30
+ return None
31
+ token_ids = self.parent.tokenizer(tokens).input_ids
32
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
33
+ return self.parent.llm.model.embed_tokens(token_ids)
34
+
35
+ def _process_features(
36
+ self,
37
+ features: torch.Tensor,
38
+ start_token_embeds: Optional[torch.Tensor],
39
+ end_token_embeds: Optional[torch.Tensor],
40
+ ) -> torch.Tensor:
41
+ if start_token_embeds is not None:
42
+ features = torch.cat([start_token_embeds, features], dim=0)
43
+ if end_token_embeds is not None:
44
+ features = torch.cat([features, end_token_embeds], dim=0)
45
+ return features
46
+
47
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
48
+ images = torch.stack(images, dim=0)
49
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
50
+ process_features = partial(
51
+ self._process_features,
52
+ start_token_embeds=self.embed_tokens(self.start_tokens),
53
+ end_token_embeds=self.embed_tokens(self.end_tokens),
54
+ )
55
+ return [process_features(f) for f in features]
56
+
57
+
58
+ class BasicVideoEncoder(BaseEncoder):
59
+ def __init__(
60
+ self,
61
+ parent: torch.nn.Module,
62
+ start_tokens: Optional[str] = None,
63
+ end_tokens: Optional[str] = "\n",
64
+ ) -> None:
65
+ super().__init__(parent)
66
+ self.start_tokens = start_tokens
67
+ self.end_tokens = end_tokens
68
+
69
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
70
+ if tokens is None:
71
+ return None
72
+ token_ids = self.parent.tokenizer(tokens).input_ids
73
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
74
+ return self.parent.llm.model.embed_tokens(token_ids)
75
+
76
+ def _process_features(
77
+ self,
78
+ features: torch.Tensor,
79
+ start_token_embeds: Optional[torch.Tensor],
80
+ end_token_embeds: Optional[torch.Tensor],
81
+ ) -> torch.Tensor:
82
+ if start_token_embeds is not None:
83
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
84
+ features = torch.cat([start_embeds, features], dim=1)
85
+ if end_token_embeds is not None:
86
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
87
+ features = torch.cat([features, end_embeds], dim=1)
88
+ return features.flatten(0, 1)
89
+
90
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
91
+ num_frames = [video.shape[0] for video in videos]
92
+ images = torch.cat(videos, dim=0)
93
+ features = self.parent.encode_images(images)
94
+ features = torch.split(features, num_frames)
95
+ process_features = partial(
96
+ self._process_features,
97
+ start_token_embeds=self.embed_tokens(self.start_tokens),
98
+ end_token_embeds=self.embed_tokens(self.end_tokens),
99
+ )
100
+ return [process_features(f) for f in features]
mm_projector/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/mm_projector",
3
+ "architectures": [
4
+ "MultimodalProjector"
5
+ ],
6
+ "mm_projector_type": "mlp_downsample_3x3_fix",
7
+ "model_type": "v2l_projector",
8
+ "torch_dtype": "bfloat16",
9
+ "transformers_version": "4.46.0"
10
+ }
mm_projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be36150c27b7f9af417f1f1584353565ff68d6b43518d10da03417292ab36e4a
3
+ size 122203760
mm_utils.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
18
+
19
+ import base64
20
+ import os
21
+ import tempfile
22
+ from io import BytesIO
23
+
24
+ import numpy as np
25
+ import torch
26
+ from PIL import Image
27
+ from transformers import StoppingCriteria
28
+
29
+ from llava.constants import DEFAULT_IMAGE_TOKEN
30
+
31
+
32
+ def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
33
+ import cv2
34
+
35
+ if fps == None or frame_count == None:
36
+ # if one of fps or frame_count is None, still recompute
37
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
38
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
39
+ if fps == 0 or frame_count == 0:
40
+ print(f"Video file not found. return empty images. {video_file_name}")
41
+ return [
42
+ Image.new("RGB", (720, 720)),
43
+ ] * num_frames, 0
44
+
45
+ duration = frame_count / fps
46
+ frame_interval = frame_count // num_frames
47
+ if frame_interval == 0 and frame_count <= 1:
48
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
49
+ return [
50
+ Image.new("RGB", (720, 720)),
51
+ ] * num_frames, 0
52
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
53
+
54
+ images = []
55
+ count = 0
56
+ success = True
57
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
58
+ while success:
59
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
60
+ if frame_count >= num_frames:
61
+ success, frame = vidcap.read()
62
+ if count in frame_indices:
63
+ try:
64
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ im_pil = Image.fromarray(img)
66
+ images.append(im_pil)
67
+ except BaseException:
68
+ continue
69
+ if len(images) >= num_frames:
70
+ return images, num_frames
71
+ count += 1
72
+ else:
73
+ # Left padding frames if the video is not long enough
74
+ success, frame = vidcap.read()
75
+ if success:
76
+ try:
77
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
78
+ im_pil = Image.fromarray(img)
79
+ images.append(im_pil)
80
+ except BaseException:
81
+ continue
82
+ count += 1
83
+ else:
84
+ break
85
+ if len(images) == 0:
86
+ raise ValueError("Did not find enough frames in the video. return empty image.")
87
+
88
+ return images, len(images)
89
+
90
+
91
+ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
92
+ """
93
+ num_frames is the max number of frames the model can support.
94
+ frame_count is the number of frames in the input video.
95
+ max_fps is the max FPS of the model can support.
96
+ fps is the fps of the input video.
97
+ """
98
+
99
+ import random
100
+
101
+ import cv2
102
+
103
+ if fps == None or frame_count == None:
104
+ # if one of fps or frame_count is None, still recompute
105
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
106
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
107
+
108
+ if fps == 0 or frame_count == 0:
109
+ print(f"Video file not found. return empty images. {video_file_name}")
110
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
111
+ return [
112
+ Image.new("RGB", (720, 720)),
113
+ ] * empty_video_frames, 0
114
+
115
+ duration = frame_count / fps
116
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
117
+ # If the video is too long (longer than max_fps and num_frames can support),
118
+ # we will use lower fps to sample frames.
119
+ if duration >= num_frames / max_fps:
120
+ frame_interval = frame_count // num_frames
121
+
122
+ # If the video is too short, we will skip the video if there is only one frame.
123
+ if frame_interval == 0 and frame_count <= 1:
124
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
125
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
126
+ return [
127
+ Image.new("RGB", (720, 720)),
128
+ ] * empty_video_frames, 0
129
+
130
+ images = []
131
+ count = 0
132
+ success = True
133
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
134
+
135
+ while success:
136
+ if frame_count >= num_frames:
137
+ # success, frame = vidcap.read()
138
+ if count in frame_indices:
139
+ success, frame = vidcap.read()
140
+ try:
141
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
142
+ im_pil = Image.fromarray(img)
143
+ images.append(im_pil)
144
+ except:
145
+ # print("Failed to read frame:", count)
146
+ continue
147
+ if len(images) >= num_frames:
148
+ return images, num_frames
149
+ else:
150
+ success = vidcap.grab()
151
+ count += 1
152
+ else:
153
+ # Left padding frames if the video is not long enough
154
+ success, frame = vidcap.read()
155
+ if success:
156
+ try:
157
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
158
+ im_pil = Image.fromarray(img)
159
+ images.append(im_pil)
160
+ except:
161
+ # print("Failed to read frame:", count)
162
+ continue
163
+ count += 1
164
+ else:
165
+ break
166
+ else:
167
+ frames_required = int(duration * max_fps)
168
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
169
+ if frames_required == 0:
170
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
171
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
172
+ return [
173
+ Image.new("RGB", (720, 720)),
174
+ ] * empty_video_frames, 0
175
+ elif frames_required == 1:
176
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
177
+ images = []
178
+ count = 0
179
+ looked = 0
180
+ success = True
181
+
182
+ while success:
183
+ success, frame = vidcap.read()
184
+ if success and (looked in frame_indices):
185
+ try:
186
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
187
+ im_pil = Image.fromarray(img)
188
+ images.append(im_pil)
189
+ except:
190
+ continue
191
+ count += 1
192
+ looked += 1
193
+
194
+ if len(images) == 0:
195
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
196
+ return [
197
+ Image.new("RGB", (720, 720)),
198
+ ] * empty_video_frames, 0
199
+ else:
200
+ return images, len(images)
201
+
202
+
203
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
204
+ """
205
+ Extract frames from a video using OpenCV.
206
+
207
+ Args:
208
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
209
+ frames (int): Number of frames to extract from the video.
210
+ fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
211
+
212
+ Returns:
213
+ list: List of PIL Images extracted from the video.
214
+
215
+ Raises:
216
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
217
+ """
218
+ import cv2
219
+
220
+ if isinstance(vpath_or_bytesio, str):
221
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
222
+ if max_fps > 0.0:
223
+ return get_frame_from_vcap_with_fps(
224
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
225
+ )
226
+ return get_frame_from_vcap(
227
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
228
+ )
229
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
230
+ # assuming mp4
231
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
232
+ temp_video.write(vpath_or_bytesio.read())
233
+ temp_video_name = temp_video.name
234
+ vidcap = cv2.VideoCapture(temp_video_name)
235
+ if max_fps > 0.0:
236
+ return get_frame_from_vcap_with_fps(
237
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
238
+ )
239
+ return get_frame_from_vcap(
240
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
241
+ )
242
+ else:
243
+ raise NotImplementedError(type(vpath_or_bytesio))
244
+
245
+
246
+ def load_image_from_base64(image):
247
+ return Image.open(BytesIO(base64.b64decode(image)))
248
+
249
+
250
+ def expand2square(pil_img, background_color):
251
+ """
252
+ Expand the given PIL image to a square shape by adding padding.
253
+
254
+ Parameters:
255
+ - pil_img: The PIL image to be expanded.
256
+ - background_color: The color of the padding to be added.
257
+
258
+ Returns:
259
+ - The expanded PIL image.
260
+
261
+ If the image is already square, it is returned as is.
262
+ If the image is wider than it is tall, padding is added to the top and bottom.
263
+ If the image is taller than it is wide, padding is added to the left and right.
264
+ """
265
+ width, height = pil_img.size
266
+ if pil_img.mode == "L":
267
+ background_color = background_color[0]
268
+ if width == height:
269
+ return pil_img
270
+ elif width > height:
271
+ result = Image.new(pil_img.mode, (width, width), background_color)
272
+ result.paste(pil_img, (0, (width - height) // 2))
273
+ return result
274
+ else:
275
+ result = Image.new(pil_img.mode, (height, height), background_color)
276
+ result.paste(pil_img, ((height - width) // 2, 0))
277
+ return result
278
+
279
+
280
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
281
+ best_ratio_diff = float("inf")
282
+ best_ratio = (1, 1)
283
+ area = width * height
284
+ for ratio in target_ratios:
285
+ target_aspect_ratio = ratio[0] / ratio[1]
286
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
287
+ if ratio_diff < best_ratio_diff:
288
+ best_ratio_diff = ratio_diff
289
+ best_ratio = ratio
290
+ elif ratio_diff == best_ratio_diff:
291
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
292
+ best_ratio = ratio
293
+ return best_ratio
294
+
295
+
296
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
297
+ orig_width, orig_height = image.size
298
+ aspect_ratio = orig_width / orig_height
299
+
300
+ # calculate the existing image aspect ratio
301
+ target_ratios = {
302
+ (i, j)
303
+ for n in range(min_num, max_num + 1)
304
+ for i in range(1, n + 1)
305
+ for j in range(1, n + 1)
306
+ if i * j <= max_num and i * j >= min_num
307
+ }
308
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
309
+
310
+ # find the closest aspect ratio to the target
311
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
312
+
313
+ # calculate the target width and height
314
+ target_width = image_size * target_aspect_ratio[0]
315
+ target_height = image_size * target_aspect_ratio[1]
316
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
317
+
318
+ # resize the image
319
+ resized_img = image.resize((target_width, target_height))
320
+ processed_images = []
321
+ for i in range(blocks):
322
+ box = (
323
+ (i % (target_width // image_size)) * image_size,
324
+ (i // (target_width // image_size)) * image_size,
325
+ ((i % (target_width // image_size)) + 1) * image_size,
326
+ ((i // (target_width // image_size)) + 1) * image_size,
327
+ )
328
+ # split the image
329
+ split_img = resized_img.crop(box)
330
+ processed_images.append(split_img)
331
+ assert len(processed_images) == blocks
332
+ if use_thumbnail and len(processed_images) != 1:
333
+ thumbnail_img = image.resize((image_size, image_size))
334
+ processed_images.append(thumbnail_img)
335
+ return processed_images
336
+
337
+
338
+ def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
339
+ orig_width, orig_height = image.size
340
+ aspect_ratio = orig_width / orig_height
341
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
342
+
343
+ processed_images = []
344
+
345
+ ##########################################################################################
346
+ ############# Add tiles for all but the last scale using fixed squre ratio ###############
347
+ ##########################################################################################
348
+
349
+ for scale in s2_scales[:-1]:
350
+ target_width = image_size * (scale // s2_scales[0])
351
+ target_height = image_size * (scale // s2_scales[0])
352
+ blocks = (scale // s2_scales[0]) ** 2
353
+
354
+ # resize the image
355
+ resized_img = image.resize((target_width, target_height))
356
+ for i in range(blocks):
357
+ box = (
358
+ (i % (target_width // image_size)) * image_size,
359
+ (i // (target_width // image_size)) * image_size,
360
+ ((i % (target_width // image_size)) + 1) * image_size,
361
+ ((i // (target_width // image_size)) + 1) * image_size,
362
+ )
363
+ # split the image
364
+ split_img = resized_img.crop(box)
365
+ processed_images.append(split_img)
366
+
367
+ ##########################################################################################
368
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
369
+ ##########################################################################################
370
+
371
+ # calculate the existing image aspect ratio
372
+ target_ratios = {
373
+ (i, j)
374
+ for n in range(min_num, max_num + 1)
375
+ for i in range(1, n + 1)
376
+ for j in range(1, n + 1)
377
+ if i * j <= max_num and i * j >= min_num
378
+ }
379
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
380
+
381
+ # find the closest aspect ratio to the target
382
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
383
+
384
+ # calculate the target width and height
385
+ target_width = image_size * target_aspect_ratio[0]
386
+ target_height = image_size * target_aspect_ratio[1]
387
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
388
+
389
+ # resize the image
390
+ resized_img = image.resize((target_width, target_height))
391
+ for i in range(blocks):
392
+ box = (
393
+ (i % (target_width // image_size)) * image_size,
394
+ (i // (target_width // image_size)) * image_size,
395
+ ((i % (target_width // image_size)) + 1) * image_size,
396
+ ((i // (target_width // image_size)) + 1) * image_size,
397
+ )
398
+ # split the image
399
+ split_img = resized_img.crop(box)
400
+ processed_images.append(split_img)
401
+
402
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
403
+
404
+
405
+ def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
406
+ prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
407
+ idx = 0
408
+ all_images = []
409
+ for img in images:
410
+ processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
411
+ all_images.append(processed_images)
412
+ prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
413
+ idx += 2
414
+ prompt = "".join(prompt)
415
+ if all_images:
416
+ all_images = torch.cat(all_images)
417
+ else:
418
+ all_images = None
419
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
420
+ return all_images, prompt
421
+
422
+
423
+ def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
424
+ idx = 0
425
+ all_images = []
426
+ all_block_size = []
427
+ for img in images:
428
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
429
+ all_images.append(processed_images)
430
+ all_block_size.append(block_size)
431
+ idx += 2
432
+ if all_images:
433
+ all_images = torch.cat(all_images)
434
+ else:
435
+ all_images = None
436
+ return all_images, all_block_size
437
+
438
+
439
+ def process_image(
440
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
441
+ ):
442
+ processor = data_args.image_processor
443
+ if isinstance(image_file, str):
444
+ if image_folder is not None:
445
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
446
+ else:
447
+ image = Image.open(image_file).convert("RGB")
448
+ else:
449
+ # image is stored in bytearray
450
+ image = image_file
451
+ image = image.convert("RGB")
452
+ if hasattr(data_args.image_processor, "crop_size"):
453
+ # CLIP vision tower
454
+ crop_size = data_args.image_processor.crop_size
455
+ else:
456
+ # SIGLIP vision tower
457
+ assert hasattr(data_args.image_processor, "size")
458
+ crop_size = data_args.image_processor.size
459
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
460
+ assert crop_size["height"] == crop_size["width"]
461
+ images, block_size = dynamic_s2_preprocess(
462
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
463
+ )
464
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
465
+ return torch.stack(images), block_size
466
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
467
+ assert crop_size["height"] == crop_size["width"]
468
+ if max_tiles is not None:
469
+ max_num = max_tiles
470
+ else:
471
+ max_num = data_args.max_tiles
472
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
473
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
474
+ return torch.stack(images)
475
+
476
+ if data_args.image_aspect_ratio == "resize":
477
+ image = image.resize((crop_size["width"], crop_size["height"]))
478
+ if data_args.image_aspect_ratio == "pad":
479
+
480
+ def expand2square(pil_img, background_color):
481
+ width, height = pil_img.size
482
+ if width == height:
483
+ return pil_img
484
+ elif width > height:
485
+ result = Image.new(pil_img.mode, (width, width), background_color)
486
+ result.paste(pil_img, (0, (width - height) // 2))
487
+ return result
488
+ else:
489
+ result = Image.new(pil_img.mode, (height, height), background_color)
490
+ result.paste(pil_img, ((height - width) // 2, 0))
491
+ return result
492
+
493
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
494
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
495
+ else:
496
+ # Using default behavior of the vision encoder
497
+ # For CLIP, default is central crop
498
+ # For Radio, default is central crop
499
+ # For Siglip, default is resize
500
+ # For InternVIT, default is resize
501
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
502
+ return image
503
+
504
+
505
+ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
506
+ model_cfg.image_processor = image_processor
507
+ new_images = [
508
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
509
+ for image in images
510
+ ]
511
+
512
+ if all(x.shape == new_images[0].shape for x in new_images):
513
+ if len(new_images[0].shape) == 4:
514
+ new_images = torch.cat(new_images, dim=0)
515
+ elif len(new_images[0].shape) == 3:
516
+ new_images = torch.stack(new_images, dim=0)
517
+ else:
518
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
519
+ else:
520
+ raise ValueError("The shape of images in new_images is different!")
521
+ return new_images
522
+
523
+
524
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
525
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
526
+
527
+
528
+ def is_gemma_tokenizer(tokenizer):
529
+ return "gemma" in tokenizer.__class__.__name__.lower()
530
+
531
+
532
+ def get_model_name_from_path(model_path):
533
+ model_path = model_path.strip("/")
534
+ model_paths = model_path.split("/")
535
+ if model_paths[-1].startswith("checkpoint-"):
536
+ return model_paths[-2] + "_" + model_paths[-1]
537
+ else:
538
+ return model_paths[-1]
539
+
540
+
541
+ class KeywordsStoppingCriteria(StoppingCriteria):
542
+ def __init__(self, keywords, tokenizer, input_ids):
543
+ self.keywords = keywords
544
+ self.keyword_ids = []
545
+ self.max_keyword_len = 0
546
+ for keyword in keywords:
547
+ cur_keyword_ids = tokenizer(keyword).input_ids
548
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
549
+ cur_keyword_ids = cur_keyword_ids[1:]
550
+ if len(cur_keyword_ids) > self.max_keyword_len:
551
+ self.max_keyword_len = len(cur_keyword_ids)
552
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
553
+ self.tokenizer = tokenizer
554
+ self.start_len = input_ids.shape[1]
555
+
556
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
557
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
558
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
559
+ for keyword_id in self.keyword_ids:
560
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
561
+ return True
562
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
563
+ for keyword in self.keywords:
564
+ if keyword in outputs:
565
+ return True
566
+ return False
567
+
568
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
569
+ outputs = []
570
+ for i in range(output_ids.shape[0]):
571
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
572
+ return all(outputs)
modeling_vila.py ADDED
@@ -0,0 +1,1024 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import copy
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import os.path
8
+ import os.path as osp
9
+ import warnings
10
+ from abc import ABC
11
+ from collections import OrderedDict, defaultdict, deque
12
+ from copy import deepcopy
13
+ from itertools import chain
14
+ from threading import Thread
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.distributed as dist
20
+ import torch.nn.functional as F
21
+ import torchvision
22
+ from einops import rearrange
23
+ from PIL import Image
24
+
25
+ from transformers import (
26
+ AutoConfig,
27
+ AutoModel,
28
+ AutoProcessor,
29
+ AutoTokenizer,
30
+ GenerationConfig,
31
+ LogitsProcessor,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ Qwen2Config,
35
+ Qwen2ForCausalLM,
36
+ Qwen2PreTrainedModel,
37
+ TextIteratorStreamer
38
+ )
39
+ from transformers.modeling_utils import ContextManagers, no_init_weights
40
+ from transformers.modeling_outputs import CausalLMOutputWithPast
41
+
42
+ from .base_projector import MultimodalProjector, MultimodalProjectorConfig
43
+ from .builder import build_llm_and_tokenizer
44
+ from .configuration_vila import VILAConfig
45
+ from .media_encoder import BasicImageEncoder, BasicVideoEncoder
46
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
47
+ from .utils import get_model_config
48
+ from .media import extract_media
49
+ from .mm_utils import process_image, process_images
50
+ from .tokenizer_utils import tokenize_conversation
51
+ from .constants import *
52
+ from .conversation import default_conversation, SeparatorStyle
53
+
54
+ # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
55
+ # quick hack for remote code
56
+ def get_pg_manager():
57
+ return None
58
+
59
+ def get_model_weights_dtype(model: nn.Module):
60
+ pass
61
+
62
+
63
+ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
64
+ if model_type_or_path is None:
65
+ return None
66
+ ## load from pretrained model
67
+ if config.resume_path:
68
+ assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
69
+ return MultimodalProjector.from_pretrained(model_type_or_path, config)
70
+ ## build from scratch
71
+ else:
72
+ mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
73
+ mm_projector = MultimodalProjector(mm_projector_cfg, config)
74
+ return mm_projector
75
+
76
+
77
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
78
+ ## skip vision tower instantiation
79
+ if model_name_or_path is None:
80
+ return None
81
+
82
+ vision_tower_arch = None
83
+ if config.resume_path and "radio" not in model_name_or_path:
84
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
85
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
86
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
87
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
88
+
89
+ use_s2 = getattr(config, "s2", False)
90
+ use_dynamic_s2 = getattr(config, "dynamic_s2", False)
91
+
92
+ if "siglip" in vision_tower_name:
93
+ if use_dynamic_s2:
94
+ vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
95
+ elif use_s2:
96
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
97
+ else:
98
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
99
+ else:
100
+ raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
101
+
102
+ config.mm_hidden_size = (
103
+ vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
104
+ )
105
+ return vision_tower
106
+
107
+
108
+ class VILAPretrainedModel(PreTrainedModel):
109
+ config_class = VILAConfig
110
+ main_input_name = "input_embeds"
111
+ supports_gradient_checkpointing = True
112
+ _supports_flash_attn_2 = True
113
+
114
+ def __init__(self, config: VILAConfig, *args, **kwargs):
115
+ super().__init__(config)
116
+ self.config = config
117
+ cfgs = get_model_config(config)
118
+ if len(cfgs) == 3:
119
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
120
+ else:
121
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
122
+
123
+ # loading on cpu by default
124
+ device_map = kwargs.get("device_map", "cpu")
125
+ self.mm_projector = build_mm_projector(mm_projector_cfg, config)
126
+ self.vision_tower = build_vision_tower(vision_tower_cfg, config)
127
+ if "auto" in device_map or "cuda" in device_map:
128
+ self.mm_projector = self.mm_projector.cuda()
129
+ self.vision_tower = self.vision_tower.cuda()
130
+ # set device_map auto can autoamtically shard llm to different devices
131
+ self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
132
+
133
+ self.encoders = {
134
+ "image": BasicImageEncoder(self),
135
+ "video": BasicVideoEncoder(self)
136
+ }
137
+
138
+ self.post_config()
139
+ self.is_loaded = True
140
+
141
+ assert (
142
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
143
+ ), "At least one of the components must be instantiated."
144
+
145
+ @classmethod
146
+ def convert_vila_dev_ckpt_to_remote(self, model_path: str, output_dir:str = None, *model_args, **kwargs):
147
+ # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
148
+ from huggingface_hub import HfApi, snapshot_download
149
+
150
+ if os.path.isdir(model_path):
151
+ model_path = model_path
152
+ api = HfApi()
153
+ if api.repo_exists(model_path):
154
+ model_path = snapshot_download(model_path, local_dir=output_dir)
155
+ print("downloading HF model to", model_path)
156
+
157
+ cfg_path = os.path.join(model_path, "config.json")
158
+ config = json.load(open(cfg_path))
159
+ config["version"] = "2.0" # nvila tag
160
+ config["architectures"] = ["VILAForCasualLM"]
161
+ config["auto_map"] = {
162
+ "AutoConfig": "modeling_vila.VILAConfig",
163
+ "AutoModel": "modeling_vila.VILAForCasualLM",
164
+ "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
165
+ }
166
+ config["model_type"] = "vila"
167
+ json.dump(config, open(cfg_path, "w"), indent=2)
168
+ self.copy_remote_py_files(model_path)
169
+
170
+ @classmethod
171
+ def copy_remote_py_files(cls, output_dir):
172
+ ## copy .py and REAMDE for next loading remote code
173
+ current_file_path = os.path.abspath(__file__)
174
+ current_folder = os.path.dirname(current_file_path)
175
+ for file_name in os.listdir(current_folder):
176
+ if file_name.endswith(".py"):
177
+ full_file_name = os.path.join(current_folder, file_name)
178
+ if os.path.isfile(full_file_name):
179
+ shutil.copy(full_file_name, output_dir)
180
+ print("[HF remote code] copying", full_file_name, "to", output_dir)
181
+
182
+ def save_pretrained(self, output_dir, state_dict=None):
183
+ if state_dict is None:
184
+ # other wise fetch from deepspeed
185
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
186
+ state_dict = self.state_dict()
187
+
188
+ if getattr(self, "tokenizer", None):
189
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
190
+
191
+ if self.get_llm():
192
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
193
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
194
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
195
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
196
+ self.config.llm_cfg = self.llm.config
197
+
198
+ if self.get_vision_tower():
199
+ print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
200
+ self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
201
+ vision_tower_state_dict = OrderedDict(
202
+ {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
203
+ )
204
+ self.vision_tower.vision_tower.save_pretrained(
205
+ os.path.join(output_dir, "vision_tower"),
206
+ state_dict=vision_tower_state_dict,
207
+ )
208
+ self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
209
+ self.config.vision_tower_cfg = self.vision_tower.config
210
+ if hasattr(self.config.vision_tower_cfg, "auto_map"):
211
+ if "radio" not in self.get_vision_tower().__class__.__name__.lower():
212
+ delattr(self.config.vision_tower_cfg, "auto_map")
213
+
214
+ if self.get_mm_projector():
215
+ print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
216
+ self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
217
+ mm_projector_state_dict = OrderedDict(
218
+ {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
219
+ )
220
+ self.mm_projector.save_pretrained(
221
+ os.path.join(output_dir, "mm_projector"),
222
+ state_dict=mm_projector_state_dict,
223
+ )
224
+ self.config.mm_projector_cfg = self.mm_projector.config
225
+
226
+ ## update and save top-level config
227
+ self.config._name_or_path = output_dir
228
+ self.config.architectures = [self.__class__.__name__]
229
+ self.config.save_pretrained(output_dir)
230
+
231
+ ## copy .py and REAMDE for next loading remote code
232
+ self.copy_remote_py_files(output_dir)
233
+
234
+
235
+
236
+ @classmethod
237
+ def from_pretrained(
238
+ cls,
239
+ pretrained_model_name_or_path: Optional[str] = None,
240
+ *model_args,
241
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
242
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
243
+ ignore_mismatched_sizes: bool = False,
244
+ force_download: bool = False,
245
+ local_files_only: bool = False,
246
+ token: Optional[Union[str, bool]] = None,
247
+ revision: str = "main",
248
+ use_safetensors: Optional[bool] = None,
249
+ weights_only: bool = True,
250
+ **kwargs,
251
+ ):
252
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
253
+ return cls._from_config(config, **kwargs)
254
+
255
+ def init_llm(self, llm_config, config, *args, **kwargs):
256
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
257
+ # hard coded for NVILA
258
+ # variables for XGrammar
259
+ # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
260
+ NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
261
+
262
+ # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
263
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
264
+ # XGrammar tokenizer and grammar compiler
265
+ # lazy init only when specified json output during inference
266
+ self.grammar_compiler = None
267
+
268
+ self.llm.resize_token_embeddings(len(self.tokenizer))
269
+ return self.llm, self.tokenizer
270
+
271
+ def post_config(self):
272
+ ######################################################################
273
+ # TODO: need to check dtype with jason
274
+ self.llm = self.llm.to(torch.float16)
275
+ self.mm_projector = self.mm_projector.to(torch.float16)
276
+ self.vision_tower = self.vision_tower.to(torch.float16)
277
+ ######################################################################
278
+ self.training = self.llm.training
279
+ ## configuration
280
+ if getattr(self.config, "llm_cfg", None) is None:
281
+ self.config.llm_cfg = self.llm.config
282
+ if getattr(self.config, "vision_tower_cfg", None) is None:
283
+ self.config.vision_tower_cfg = self.vision_tower.config
284
+ if getattr(self.config, "mm_projector_cfg", None) is None:
285
+ self.config.mm_projector_cfg = self.mm_projector.config
286
+
287
+ def get_llm(self):
288
+ llm = getattr(self, "llm", None)
289
+ if type(llm) is list:
290
+ llm = llm[0]
291
+ return llm
292
+
293
+ def get_lm_head(self):
294
+ lm_head = getattr(self.get_llm(), "lm_head", None)
295
+ return lm_head
296
+
297
+ def get_vision_tower(self):
298
+ vision_tower = getattr(self, "vision_tower", None)
299
+ if type(vision_tower) is list:
300
+ vision_tower = vision_tower[0]
301
+ return vision_tower
302
+
303
+ def get_mm_projector(self):
304
+ mm_projector = getattr(self, "mm_projector", None)
305
+ if type(mm_projector) is list:
306
+ mm_projector = mm_projector[0]
307
+ return mm_projector
308
+
309
+ def freezed_module_patch(self):
310
+ """
311
+ Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
312
+ """
313
+ if self.training:
314
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
315
+ pass
316
+ # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
317
+ if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
318
+ self.get_vision_tower().eval()
319
+ if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
320
+ self.get_mm_projector().eval()
321
+
322
+ class VILAForCasualLM(VILAPretrainedModel):
323
+ def __init__(self, config: VILAConfig, *args, **kwargs):
324
+ super().__init__(config, *args, **kwargs)
325
+
326
+ def merge_features_for_dynamic_s2(self, image_features, block_sizes):
327
+ scales = self.get_vision_tower().scales
328
+ resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
329
+
330
+ image_features_each_image = []
331
+ new_block_sizes = []
332
+ block_cnt = 0
333
+ for block_size_each_image in block_sizes:
334
+ if block_size_each_image is None:
335
+ cur_features = image_features[block_cnt : block_cnt + 1]
336
+ cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
337
+ cur_features = cur_features.repeat(1, len(scales), 1, 1)
338
+ image_features_each_image.append(cur_features)
339
+ new_block_sizes.append((1, 1))
340
+ block_cnt += 1
341
+ else:
342
+ cur_features_each_scale = []
343
+ for scale in scales[:-1]:
344
+ num_blocks_this_scale = (scale // scales[0]) ** 2
345
+ cur_features_each_scale.append(
346
+ self.merge_chessboard(
347
+ image_features[block_cnt : block_cnt + num_blocks_this_scale],
348
+ num_split_h=scale // scales[0],
349
+ num_split_w=scale // scales[0],
350
+ )
351
+ ) # 1 * C * H * W
352
+ block_cnt += num_blocks_this_scale
353
+ num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
354
+ cur_features_each_scale.append(
355
+ self.merge_chessboard(
356
+ image_features[block_cnt : block_cnt + num_blocks_last_scale],
357
+ num_split_h=block_size_each_image[0],
358
+ num_split_w=block_size_each_image[1],
359
+ )
360
+ ) # 1 * C * H * W
361
+ block_cnt += num_blocks_last_scale
362
+
363
+ # resize and concat features from different scales
364
+ output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
365
+ cur_features = torch.cat(
366
+ [
367
+ F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
368
+ cur_features_each_scale[i].dtype
369
+ )
370
+ for i in range(len(cur_features_each_scale))
371
+ ],
372
+ dim=1,
373
+ )
374
+ # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")
375
+
376
+ image_features_each_image.append(cur_features)
377
+
378
+ if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
379
+ new_block_sizes.append(block_size_each_image)
380
+ else:
381
+ new_block_sizes.append(
382
+ (
383
+ scales[resize_output_to_scale_idx] // scales[0],
384
+ scales[resize_output_to_scale_idx] // scales[0],
385
+ )
386
+ )
387
+
388
+ assert block_cnt == len(image_features)
389
+
390
+ return image_features_each_image, new_block_sizes
391
+
392
+ def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None):
393
+ if block_sizes is None:
394
+ block_sizes = [None] * len(images)
395
+ if getattr(self.config, "dynamic_s2", False):
396
+ image_features = self.get_vision_tower()(images)
397
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
398
+
399
+ image_features = [
400
+ self.split_chessboard(x, block_size[0], block_size[1])
401
+ for x, block_size in zip(image_features, new_block_sizes)
402
+ ] # list of B * C * H * W tensors
403
+ image_features = torch.cat(
404
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
405
+ ) # B * N * C
406
+ image_features = self.get_mm_projector()(image_features)
407
+ image_features = list(
408
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
409
+ )
410
+ image_features = [
411
+ self.merge_chessboard(x, block_size[0], block_size[1])
412
+ for x, block_size in zip(image_features, new_block_sizes)
413
+ ] # list of 1 * C * H * W tensors
414
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
415
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
416
+ image_features = torch.stack(image_features, dim=0)
417
+ else:
418
+ image_features = self.get_vision_tower()(images)
419
+ image_features = self.get_mm_projector()(image_features)
420
+ return image_features
421
+
422
+ def _embed(
423
+ self,
424
+ input_ids: torch.Tensor,
425
+ media: Dict[str, List[torch.Tensor]],
426
+ media_config: Dict[str, Dict[str, Any]],
427
+ labels: Optional[torch.Tensor],
428
+ attention_mask: Optional[torch.Tensor],
429
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
430
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
431
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
432
+
433
+ # PROCESS_GROUP_MANAGER = get_pg_manager()
434
+ PROCESS_GROUP_MANAGER = None
435
+ if PROCESS_GROUP_MANAGER is not None:
436
+ for name in media:
437
+ self.encoders[name].end_tokens = None
438
+
439
+ # Extract text and media embeddings
440
+ text_embeds = self.llm.model.embed_tokens(input_ids)
441
+ media_embeds = self.__embed_media_tokens(media, media_config)
442
+
443
+ # This is a workaround to make sure the dummy embeddings are consumed
444
+ while media_embeds.get("dummy"):
445
+ dummy_embed = media_embeds["dummy"].popleft()
446
+ text_embeds += torch.sum(dummy_embed) * 0
447
+
448
+ # Remove padding
449
+ batch_size = labels.shape[0]
450
+ text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
451
+ labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
452
+
453
+ # Build inverse mapping from token ID to media name
454
+ media_tokens = {}
455
+ for name, token_id in self.tokenizer.media_token_ids.items():
456
+ media_tokens[token_id] = name
457
+
458
+ # Fuse text and media embeddings
459
+ inputs_m, labels_m = [], []
460
+ for k in range(batch_size):
461
+ inputs_mk, labels_mk = [], []
462
+ pos = 0
463
+ while pos < len(labels[k]):
464
+ if input_ids[k][pos].item() in media_tokens:
465
+ end = pos + 1
466
+ name = media_tokens[input_ids[k][pos].item()]
467
+ input = media_embeds[name].popleft()
468
+ label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
469
+ else:
470
+ end = pos
471
+ while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
472
+ end += 1
473
+ input = text_embeds[k][pos:end]
474
+ label = labels[k][pos:end]
475
+ inputs_mk.append(input)
476
+ labels_mk.append(label)
477
+ pos = end
478
+ inputs_m.append(torch.cat(inputs_mk, dim=0))
479
+ labels_m.append(torch.cat(labels_mk, dim=0))
480
+ inputs, labels = inputs_m, labels_m
481
+
482
+ # Check if all media embeddings are consumed
483
+ for name in media_embeds:
484
+ if media_embeds[name]:
485
+ raise ValueError(f"Not all {name} embeddings are consumed!")
486
+
487
+ # Truncate sequences to `model_max_length` as media embeddings are inserted
488
+ inputs, labels = self.__truncate_sequence(inputs, labels)
489
+
490
+ # Pad sequences to the longest one in the batch
491
+ return self.__batchify_sequence(inputs, labels)
492
+
493
+ def __embed_media_tokens(
494
+ self,
495
+ media: Dict[str, List[torch.Tensor]],
496
+ media_config: Dict[str, Dict[str, Any]],
497
+ ) -> Dict[str, List[torch.Tensor]]:
498
+ embeds = defaultdict(deque)
499
+ for name in media:
500
+ if self.training:
501
+ # Gather metainfo of media objects from all ranks
502
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
503
+ infos = list(chain(*distributed.all_gather(info)))
504
+
505
+ # The entire batch does not contain any media objects of this type.
506
+ if not infos:
507
+ continue
508
+
509
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
510
+ if media.get(name) is None or len(media[name]) == 0:
511
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
512
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
513
+ continue
514
+ embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
515
+ return embeds
516
+
517
+ def __truncate_sequence(
518
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
519
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
520
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
521
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
522
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
523
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
524
+ return inputs, labels
525
+
526
+ def __batchify_sequence(
527
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
528
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
529
+ batch_size = len(inputs)
530
+ device = inputs[0].device
531
+ hidden_size = inputs[0].shape[1]
532
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
533
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
534
+
535
+ inputs_p, labels_p = [], []
536
+ for k in range(batch_size):
537
+ size_pk = max_length - inputs[k].shape[0]
538
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
539
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
540
+ if self.tokenizer.padding_side == "right":
541
+ attention_mask[k, inputs[k].shape[0] :] = False
542
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
543
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
544
+ else:
545
+ attention_mask[k, : -inputs[k].shape[0]] = False
546
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
547
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
548
+ inputs_p.append(inputs_pk)
549
+ labels_p.append(labels_pk)
550
+
551
+ inputs = torch.stack(inputs_p, dim=0)
552
+ labels = torch.stack(labels_p, dim=0)
553
+ return inputs, labels, attention_mask
554
+
555
+ def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
556
+ # Handle sequence parallelism
557
+ PROCESS_GROUP_MANAGER = get_pg_manager()
558
+
559
+ # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
560
+ if PROCESS_GROUP_MANAGER is not None:
561
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
562
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
563
+ sp_group = PROCESS_GROUP_MANAGER.sp_pg
564
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
565
+ ring_rank = PROCESS_GROUP_MANAGER.ring_rank
566
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
567
+ ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
568
+ ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
569
+
570
+ bs, shard_seqlen = position_ids.shape
571
+ sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
572
+ dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
573
+ sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
574
+
575
+ if sp_rank == 0:
576
+ original_start_id = 0
577
+ else:
578
+ original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
579
+ original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
580
+
581
+ # Gather attention_mask, position_ids, labels and input_embeds
582
+ all_inputs_embeds = torch.zeros(
583
+ bs,
584
+ torch.sum(sp_seq_len_cat),
585
+ inputs_embeds.shape[-1],
586
+ dtype=inputs_embeds.dtype,
587
+ device=inputs_embeds.device,
588
+ ).contiguous()
589
+ all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
590
+ dist.barrier(group=sp_group)
591
+ dist.all_reduce(all_inputs_embeds, group=sp_group)
592
+ dist.barrier(group=sp_group)
593
+
594
+ attention_mask_list = [
595
+ torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
596
+ for i in range(sp_degree)
597
+ ]
598
+ position_ids_list = [
599
+ torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
600
+ for i in range(sp_degree)
601
+ ]
602
+ labels_list = [
603
+ torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
604
+ ]
605
+
606
+ dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
607
+ dist.all_gather(position_ids_list, position_ids, group=sp_group)
608
+ dist.all_gather(labels_list, labels, group=sp_group)
609
+
610
+ effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
611
+ effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
612
+ effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
613
+
614
+ global_attention_mask_list = []
615
+ global_position_ids_list = []
616
+ global_labels_list = []
617
+ global_inputs_embeds_list = []
618
+ for i in range(bs):
619
+ global_attention_mask_batch_list = []
620
+ global_position_ids_batch_list = []
621
+ global_labels_batch_list = []
622
+ global_inputs_embeds_batch_list = []
623
+ for j in range(sp_degree):
624
+ eff_len = effective_seqlen_batch_list[i][j]
625
+ prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
626
+
627
+ global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
628
+ global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
629
+ global_labels_batch_list.append(labels_list[j][i, :eff_len])
630
+ global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
631
+ global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
632
+ global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
633
+ global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
634
+ global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
635
+
636
+ global_attention_mask = torch.nn.utils.rnn.pad_sequence(
637
+ global_attention_mask_list, batch_first=True, padding_value=False
638
+ )
639
+ global_position_ids = torch.nn.utils.rnn.pad_sequence(
640
+ global_position_ids_list, batch_first=True, padding_value=-1
641
+ )
642
+ global_labels = torch.nn.utils.rnn.pad_sequence(
643
+ global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
644
+ )
645
+ global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
646
+ global_inputs_embeds_list, batch_first=True, padding_value=0
647
+ )
648
+
649
+ # Re-shard the inputs
650
+ if ring_degree > 1:
651
+ total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
652
+ new_seqlen_per_rank = total_effective_seqlen // sp_degree
653
+ assert torch.all(
654
+ total_effective_seqlen % sp_degree == 0
655
+ ), "total_effective_seqlen must be divisible by sp_degree"
656
+
657
+ max_new_seqlen = torch.max(new_seqlen_per_rank).item()
658
+
659
+ new_attention_mask = torch.zeros(
660
+ (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
661
+ )
662
+ new_position_ids = torch.zeros(
663
+ (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
664
+ )
665
+ new_labels = torch.full(
666
+ (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
667
+ )
668
+ new_inputs_embeds = torch.zeros(
669
+ (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
670
+ dtype=global_inputs_embeds.dtype,
671
+ device=global_inputs_embeds.device,
672
+ )
673
+
674
+ if ring_type == "ring_varlen":
675
+ for i in range(bs):
676
+ start_idx = new_seqlen_per_rank[i] * sp_rank
677
+ end_idx = start_idx + new_seqlen_per_rank[i]
678
+ new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
679
+ new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
680
+ new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
681
+ new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
682
+ i, start_idx:end_idx, :
683
+ ]
684
+ elif ring_type == "zigzag_ring_varlen":
685
+ chunk_size = total_effective_seqlen // (2 * sp_degree)
686
+ for i in range(bs):
687
+ # Zigzag pattern indices
688
+ if sp_degree == ring_degree:
689
+ forward_rank_idx = sp_rank
690
+ backward_rank_idx = 2 * sp_degree - sp_rank - 1
691
+ else:
692
+ ulysses_offset = ulysses_rank * ring_degree * 2
693
+ forward_rank_idx = ring_rank + ulysses_offset
694
+ backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
695
+
696
+ # Calculate start and end indices for the forward and backward zigzag
697
+ start_idx_fwd = forward_rank_idx * chunk_size[i]
698
+ end_idx_fwd = start_idx_fwd + chunk_size[i]
699
+
700
+ start_idx_bwd = backward_rank_idx * chunk_size[i]
701
+ end_idx_bwd = start_idx_bwd + chunk_size[i]
702
+
703
+ # Fill new tensors with zigzag data
704
+ new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
705
+ new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
706
+ i, start_idx_bwd:end_idx_bwd
707
+ ]
708
+
709
+ new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
710
+ new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
711
+ i, start_idx_bwd:end_idx_bwd
712
+ ]
713
+
714
+ new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
715
+ new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
716
+
717
+ new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
718
+ new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
719
+ i, start_idx_bwd:end_idx_bwd, :
720
+ ]
721
+ else:
722
+ raise ValueError(f"Invalid ring_type: {ring_type}")
723
+ else:
724
+ global_seq_len = global_attention_mask.shape[-1]
725
+ seq_len_sharded = global_seq_len // sp_degree
726
+ start_idx_reshard = seq_len_sharded * sp_rank
727
+ end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
728
+
729
+ new_attention_mask = torch.narrow(
730
+ global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
731
+ )
732
+ new_position_ids = torch.narrow(
733
+ global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
734
+ )
735
+ new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
736
+ new_inputs_embeds = torch.narrow(
737
+ global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
738
+ )
739
+
740
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
741
+
742
+ device = inputs_embeds.device
743
+ batch_size = inputs_embeds.shape[0]
744
+ seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
745
+
746
+ # Pack all sequences together
747
+ inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
748
+ attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
749
+ position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
750
+ labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
751
+
752
+ # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
753
+ inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
754
+ attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
755
+ position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
756
+ labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
757
+
758
+ # Mask the first token of each sequence to avoid contamination
759
+ for label in labels_p:
760
+ label[0] = IGNORE_INDEX
761
+
762
+ # Batch the data
763
+ inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
764
+ attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
765
+ position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
766
+ labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
767
+
768
+ if hasattr(
769
+ self, "pad_to_multiple_of"
770
+ ): # related to quantization, please refer to ModelArguments for more information.
771
+ assert len(labels_p.shape) == 2
772
+ batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
773
+ hidden_size = inputs_embeds_p.shape[-1]
774
+
775
+ if max_length % self.pad_to_multiple_of != 0:
776
+ max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
777
+ difference = max_length - cur_length
778
+
779
+ inputs_embeds_p = torch.cat(
780
+ (
781
+ inputs_embeds_p,
782
+ torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
783
+ ),
784
+ dim=1,
785
+ )
786
+ labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
787
+ attention_mask_p = torch.cat(
788
+ (
789
+ attention_mask_p,
790
+ torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
791
+ ),
792
+ dim=1,
793
+ )
794
+ position_ids_p = torch.cat(
795
+ (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
796
+ )
797
+
798
+ return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
799
+
800
+ def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
801
+ raise NotImplementedError("This method is not implemented for VILA model.")
802
+ # Convert response format to logits processor
803
+ import xgrammar as xgr
804
+
805
+ logging.info("[XGrammar] Compiling grammar for contrained output")
806
+
807
+ if self.grammar_compiler is None:
808
+ # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
809
+ self.grammar_compiler = xgr.GrammarCompiler(
810
+ xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
811
+ )
812
+
813
+ if response_format.type == "json_schema":
814
+ compiled_grammar = self.grammar_compiler.compile_json_schema(
815
+ response_format.json_schema.schema_,
816
+ indent=2,
817
+ )
818
+ else:
819
+ compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
820
+
821
+ return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
822
+
823
+ def forward(
824
+ self,
825
+ input_ids: torch.LongTensor = None,
826
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
827
+ images: Optional[torch.FloatTensor] = None,
828
+ media_config: Optional[List] = None,
829
+ attention_mask: Optional[torch.Tensor] = None,
830
+ position_ids: Optional[torch.LongTensor] = None,
831
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
832
+ inputs_embeds: Optional[torch.FloatTensor] = None,
833
+ labels: Optional[torch.LongTensor] = None,
834
+ packing: bool = True,
835
+ force_packing: bool = False,
836
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
837
+ dpo_forward: bool = False,
838
+ **kwargs,
839
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
840
+ self.freezed_module_patch()
841
+
842
+ if images is not None:
843
+ if media is not None:
844
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
845
+ print("The 'images' argument is deprecated. Please use 'media' instead.")
846
+ media = {"image": images}
847
+
848
+ if media_config is None:
849
+ media_config = defaultdict(dict)
850
+
851
+ if inputs_embeds is None:
852
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
853
+
854
+ if force_packing or (packing and self.training and not dpo_forward):
855
+ if seqlens_in_batch is None:
856
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
857
+ set_seqlens_in_batch(seqlens_in_batch)
858
+
859
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
860
+ inputs_embeds, attention_mask, position_ids, labels
861
+ )
862
+
863
+ outputs = self.llm(
864
+ inputs_embeds=inputs_embeds,
865
+ attention_mask=attention_mask,
866
+ position_ids=position_ids,
867
+ past_key_values=past_key_values,
868
+ labels=labels,
869
+ **kwargs,
870
+ )
871
+
872
+ if self.training and getattr(self.config, "time_token_ids", []):
873
+ outputs.loss = soft_cross_entropy(
874
+ outputs.logits,
875
+ labels,
876
+ soft_tokens=self.config.time_token_ids,
877
+ std=self.config.soft_ce_std,
878
+ )
879
+
880
+ if dpo_forward:
881
+ return outputs.logits, labels
882
+
883
+ return outputs
884
+ @torch.inference_mode()
885
+ def generate(
886
+ self,
887
+ input_ids: Optional[torch.FloatTensor] = None,
888
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
889
+ media_config: Dict[str, Dict[str, Any]] = None,
890
+ attention_mask: Optional[torch.LongTensor] = None,
891
+ **generation_kwargs,
892
+ ):
893
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
894
+ return self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
895
+
896
+ @torch.inference_mode()
897
+ def generate_content(
898
+ self,
899
+ prompt: Union[str, List],
900
+ generation_config: Optional[GenerationConfig] = None,
901
+ response_format = None,
902
+ ) -> str:
903
+ # TODO(zhijianl): Support directly taking conversation as input
904
+ conversation = [{"from": "human", "value": prompt}]
905
+
906
+ # Convert response format to logits processor
907
+ if response_format:
908
+ xgr_logits_processor = self.get_xgr_logits_processor(response_format)
909
+ else:
910
+ xgr_logits_processor = None
911
+
912
+ # Extract media from the conversation
913
+
914
+ # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
915
+ media = extract_media(conversation, self.config)
916
+
917
+ # Process media
918
+ media_config = defaultdict(dict)
919
+ for name in media:
920
+ if name == "image":
921
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
922
+ self.config.image_processor = self.vision_tower.image_processor
923
+ if self.config.image_aspect_ratio == "dynamic":
924
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
925
+ conversation[0]["value"] = conversation[0]["value"].replace(
926
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
927
+ )
928
+ else:
929
+ if type(self.config.s2_scales) is str:
930
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
931
+ images, block_sizes = process_image(
932
+ media["image"][0], self.config, None, enable_dynamic_s2=True
933
+ )
934
+ images = images.half()
935
+ media_config[name]["block_sizes"] = [block_sizes]
936
+ else:
937
+ images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
938
+ media[name] = [image for image in images]
939
+ elif name == "video":
940
+ if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
941
+ media[name] = [
942
+ process_images(
943
+ images,
944
+ self.vision_tower.image_processor,
945
+ self.config,
946
+ enable_dynamic_res=True,
947
+ max_tiles=self.config.video_max_tiles,
948
+ ).half()
949
+ for images in media[name]
950
+ ]
951
+ elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
952
+ self.config.image_processor = self.vision_tower.image_processor
953
+ if type(self.config.s2_scales) is str:
954
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
955
+ media[name] = [
956
+ torch.cat(
957
+ [
958
+ process_image(
959
+ image,
960
+ self.config,
961
+ None,
962
+ enable_dynamic_s2=True,
963
+ max_tiles=self.config.video_max_tiles,
964
+ )[0].half()
965
+ for image in images
966
+ ]
967
+ )
968
+ for images in media[name]
969
+ ]
970
+ else:
971
+ media[name] = [
972
+ process_images(images, self.vision_tower.image_processor, self.config).half()
973
+ for images in media[name]
974
+ ]
975
+ else:
976
+ raise ValueError(f"Unsupported media type: {name}")
977
+
978
+ # Tokenize the conversation
979
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
980
+
981
+ # Set up the generation config
982
+ generation_config = generation_config or self.default_generation_config
983
+
984
+ # Generate the response
985
+ try:
986
+ output_ids = self.generate(
987
+ input_ids=input_ids,
988
+ media=media,
989
+ media_config=media_config,
990
+ generation_config=generation_config,
991
+ logits_processor=xgr_logits_processor, # structured generation
992
+ )
993
+ except ValueError:
994
+ if not generation_config.do_sample:
995
+ raise
996
+ # FIXME(zhijianl): This is a temporary workaround for the sampling issue
997
+ logging.warning("Generation failed with sampling, retrying with greedy decoding.")
998
+ generation_config.do_sample = False
999
+ output_ids = self.generate(
1000
+ input_ids=input_ids,
1001
+ media=media,
1002
+ media_config=media_config,
1003
+ generation_config=generation_config,
1004
+ logits_processor=xgr_logits_processor,
1005
+ )
1006
+
1007
+ # Decode the response
1008
+ response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
1009
+ return response
1010
+
1011
+ @property
1012
+ def default_generation_config(self) -> GenerationConfig:
1013
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1014
+ if self.tokenizer.eos_token_id is None:
1015
+ raise ValueError("Tokenizer must have an EOS token")
1016
+ if generation_config.max_length == GenerationConfig().max_length:
1017
+ generation_config.max_length = self.tokenizer.model_max_length
1018
+ if generation_config.pad_token_id is None:
1019
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1020
+ if generation_config.bos_token_id is None:
1021
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1022
+ if generation_config.eos_token_id is None:
1023
+ generation_config.eos_token_id = self.tokenizer.eos_token_id
1024
+ return generation_config
siglip_encoder.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from accelerate.hooks import add_hook_to_module
21
+ from einops import rearrange
22
+ from s2wrapper import forward as multiscale_forward
23
+ from transformers import AutoConfig, PreTrainedModel
24
+ from transformers.image_processing_utils import BaseImageProcessor
25
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
26
+ from transformers.models.siglip import SiglipVisionModel
27
+ from transformers import PretrainedConfig, SiglipImageProcessor
28
+
29
+ class VisionTower(nn.Module):
30
+ def __init__(self, vision_tower, args, delay_load=False):
31
+ super().__init__()
32
+
33
+ self.is_loaded = False
34
+
35
+ self.vision_tower_name = vision_tower
36
+ self.select_layer = getattr(args, "mm_vision_select_layer", -2)
37
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
38
+
39
+ self.cfg_only = None
40
+
41
+ def feature_select(self, image_forward_outs):
42
+ image_features = image_forward_outs.hidden_states[self.select_layer]
43
+ if self.select_feature == "patch":
44
+ image_features = image_features[:, 1:]
45
+ elif self.select_feature == "cls_patch":
46
+ image_features = image_features
47
+ else:
48
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
49
+ return image_features
50
+
51
+ def _maybe_resize_pos_embeds(
52
+ self,
53
+ model: PreTrainedModel,
54
+ image_processor: BaseImageProcessor,
55
+ resolution: int = -1,
56
+ interpolate_mode: str = "linear",
57
+ ):
58
+ if resolution in [model.config.image_size, -1]:
59
+ return
60
+ print(
61
+ f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
62
+ )
63
+ embeddings = model.vision_model.embeddings
64
+ patch_size = embeddings.patch_size
65
+ num_new_tokens = int((resolution // patch_size) ** 2)
66
+
67
+ old_embeddings = embeddings.position_embedding
68
+ match interpolate_mode:
69
+ case "linear":
70
+ ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
71
+ ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ if is_deepspeed_zero3_enabled():
76
+ import deepspeed
77
+
78
+ with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
79
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
80
+ else:
81
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
82
+ new_embeddings = nn.Embedding(
83
+ num_new_tokens,
84
+ old_embedding_dim,
85
+ dtype=old_embeddings.weight.dtype,
86
+ device=old_embeddings.weight.device,
87
+ )
88
+ mapped_indices = (
89
+ torch.arange(num_new_tokens).to(old_embeddings.weight.device)
90
+ / (num_new_tokens - 1)
91
+ * (old_num_tokens - 1)
92
+ )
93
+ floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
94
+ ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
95
+ if is_deepspeed_zero3_enabled():
96
+ params = [old_embeddings.weight, new_embeddings.weight]
97
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
98
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
99
+ ceil_indices, :
100
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
101
+ else:
102
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
103
+ ceil_indices, :
104
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
105
+ new_embeddings.weight.data = interpolated_embeds
106
+ case _:
107
+ raise NotImplementedError
108
+
109
+ if hasattr(old_embeddings, "_hf_hook"):
110
+ hook = old_embeddings._hf_hook
111
+ add_hook_to_module(new_embeddings, hook)
112
+ new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
113
+ ## update vision encoder's configurations
114
+ model.config.image_size = resolution
115
+ if hasattr(image_processor, "crop_size"):
116
+ # CLIP vision tower
117
+ image_processor.crop_size = resolution
118
+ else:
119
+ # SIGLIP vision tower
120
+ assert hasattr(image_processor, "size")
121
+ image_processor.size = {"height": resolution, "width": resolution}
122
+ ## TODO define a '_reinitialize' method for VisionTower
123
+ embeddings.position_embedding = new_embeddings
124
+ embeddings.image_size = resolution
125
+ embeddings.num_patches = embeddings.num_positions = num_new_tokens
126
+ embeddings.position_ids = (
127
+ torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
128
+ )
129
+
130
+ def forward(self, images):
131
+ if type(images) is list:
132
+ image_features = []
133
+ for image in images:
134
+ image_forward_out = self.vision_tower(
135
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
136
+ output_hidden_states=True,
137
+ )
138
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
139
+ image_features.append(image_feature)
140
+ else:
141
+ image_forward_outs = self.vision_tower(
142
+ images.to(device=self.device, dtype=self.dtype),
143
+ output_hidden_states=True,
144
+ )
145
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
146
+
147
+ return image_features
148
+
149
+
150
+ @property
151
+ def dummy_feature(self):
152
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
153
+
154
+ @property
155
+ def dtype(self):
156
+ return self.vision_tower.dtype
157
+
158
+ @property
159
+ def device(self):
160
+ return self.vision_tower.device
161
+
162
+ @property
163
+ def config(self):
164
+ if self.is_loaded:
165
+ return self.vision_tower.config
166
+ else:
167
+ return self.cfg_only
168
+
169
+ @property
170
+ def hidden_size(self):
171
+ return self.config.hidden_size
172
+
173
+ @property
174
+ def num_patches(self):
175
+ return (self.config.image_size // self.config.patch_size) ** 2
176
+
177
+
178
+ class VisionTowerS2(VisionTower):
179
+ def __init__(self, vision_tower, args, delay_load=False):
180
+ super().__init__(vision_tower, args, delay_load)
181
+
182
+ self.scales = list(map(int, args.s2_scales.split(",")))
183
+ self.scales.sort()
184
+ self.max_split_size = args.s2_max_split_size
185
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
186
+
187
+ def forward_feature(self, images):
188
+ image_forward_outs = self.vision_tower(
189
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
190
+ )
191
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
192
+ return image_features
193
+
194
+ def forward(self, images):
195
+ if type(images) is list:
196
+ image_feature = []
197
+ for image in images:
198
+ image_feature = multiscale_forward(
199
+ self.forward_feature,
200
+ image.unsqueeze(0),
201
+ img_sizes=self.scales,
202
+ max_split_size=self.max_split_size,
203
+ resize_output_to_idx=self.resize_output_to_scale_idx,
204
+ )
205
+ image_features.append(image_feature)
206
+ else:
207
+ image_features = multiscale_forward(
208
+ self.forward_feature,
209
+ images,
210
+ img_sizes=self.scales,
211
+ max_split_size=self.max_split_size,
212
+ resize_output_to_idx=self.resize_output_to_scale_idx,
213
+ )
214
+
215
+ return image_features
216
+
217
+ @property
218
+ def hidden_size(self):
219
+ return self.config.hidden_size * len(self.scales)
220
+
221
+
222
+ class VisionTowerDynamicS2(VisionTower):
223
+ def __init__(self, vision_tower, args, delay_load=False):
224
+ super().__init__(vision_tower, args, delay_load)
225
+
226
+ self.scales = list(map(int, args.s2_scales.split(",")))
227
+ self.scales.sort()
228
+ self.max_split_size = args.s2_max_split_size
229
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
230
+
231
+ def forward_feature(self, images):
232
+ image_forward_outs = self.vision_tower(
233
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
234
+ )
235
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
236
+ return image_features
237
+
238
+ def forward(self, images):
239
+ assert type(images) is not list
240
+ image_features = self.forward_feature(images)
241
+
242
+ return image_features
243
+
244
+ @property
245
+ def hidden_size(self):
246
+ return self.config.hidden_size * len(self.scales)
247
+
248
+
249
+ class SiglipVisionTower(VisionTower):
250
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
251
+ super().__init__(model_name_or_path, config)
252
+ # TODO(ligengl): why pass config here leading to errors?
253
+ self.vision_tower = SiglipVisionModel.from_pretrained(
254
+ model_name_or_path,
255
+ attn_implementation=config._attn_implementation,
256
+ torch_dtype=eval(config.model_dtype),
257
+ )
258
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
259
+ self.is_loaded = True
260
+
261
+
262
+ class SiglipVisionTowerS2(VisionTowerS2):
263
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
264
+ super().__init__(model_name_or_path, config)
265
+ self.vision_tower = SiglipVisionModel.from_pretrained(
266
+ model_name_or_path,
267
+ attn_implementation=config._attn_implementation,
268
+ torch_dtype=eval(config.model_dtype),
269
+ )
270
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
271
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
272
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
273
+ self.is_loaded = True
274
+
275
+
276
+ class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
277
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
278
+ super().__init__(model_name_or_path, config)
279
+ self.vision_tower = SiglipVisionModel.from_pretrained(
280
+ model_name_or_path,
281
+ attn_implementation="flash_attention_2",
282
+ torch_dtype=eval(config.model_dtype),
283
+ )
284
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
285
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
286
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
287
+ self.is_loaded = True
tokenizer_utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from typing import Any, Dict, List, Optional, Sequence
18
+
19
+ import torch
20
+ import transformers
21
+
22
+ from .conversation import default_conversation, SeparatorStyle
23
+ from .mm_utils import tokenizer_image_token
24
+ from .constants import IGNORE_INDEX, SENTINEL_TOKEN
25
+
26
+ # __all__ = [
27
+ # "tokenize_conversation",
28
+ # "preprocess_conversation",
29
+ # "infer_stop_tokens",
30
+ # ]
31
+
32
+ DUMMY_CONVERSATION = [
33
+ {"from": "human", "value": "question"},
34
+ {"from": "gpt", "value": "answer"},
35
+ ] * 10
36
+
37
+
38
+ def tokenize_conversation_legacy(
39
+ messages: Sequence[Dict[str, str]],
40
+ tokenizer: transformers.PreTrainedTokenizer,
41
+ add_generation_prompt: bool = False,
42
+ overrides: Optional[Dict[str, str]] = None,
43
+ no_system_prompt: bool = False,
44
+ ) -> torch.Tensor:
45
+ conv = default_conversation.copy()
46
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
47
+
48
+ if no_system_prompt:
49
+ conv.system = ""
50
+
51
+ # Skip the first message if it is not from human
52
+ if messages[0]["from"] != "human":
53
+ messages = messages[1:]
54
+
55
+ # Add a generation prompt if needed
56
+ if add_generation_prompt:
57
+ messages.append({"from": "gpt", "value": None})
58
+
59
+ conv.messages = []
60
+ for turn, message in enumerate(messages):
61
+ role = roles[message["from"]]
62
+ assert role == conv.roles[turn % 2]
63
+ if overrides is not None and message["from"] in overrides:
64
+ conv.append_message(role, overrides[message["from"]])
65
+ else:
66
+ conv.append_message(role, message["value"])
67
+
68
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
69
+
70
+
71
+ def tokenize_conversation(
72
+ messages: Sequence[Dict[str, str]],
73
+ tokenizer: transformers.PreTrainedTokenizer,
74
+ add_generation_prompt: bool = False,
75
+ overrides: Optional[Dict[str, str]] = None,
76
+ no_system_prompt: bool = False,
77
+ ) -> torch.Tensor:
78
+ # Normalize the conversation before tokenization
79
+ for message in messages:
80
+ message["value"] = message["value"].strip()
81
+
82
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
83
+ return tokenize_conversation_legacy(
84
+ messages,
85
+ tokenizer,
86
+ add_generation_prompt=add_generation_prompt,
87
+ overrides=overrides,
88
+ no_system_prompt=no_system_prompt,
89
+ )
90
+
91
+ conversation = []
92
+ for m in messages:
93
+ message = {}
94
+ if m["from"] == "human":
95
+ message["role"] = "user"
96
+ elif m["from"] == "gpt":
97
+ message["role"] = "assistant"
98
+ else:
99
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
100
+
101
+ message["content"] = m["value"]
102
+ if overrides is not None and m["from"] in overrides:
103
+ message["content"] = overrides[m["from"]]
104
+ conversation.append(message)
105
+
106
+ if no_system_prompt:
107
+ conversation = [{"role": "system", "content": ""}] + conversation
108
+
109
+ text = tokenizer.apply_chat_template(
110
+ conversation,
111
+ add_generation_prompt=add_generation_prompt,
112
+ tokenize=False,
113
+ )
114
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
115
+
116
+
117
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
118
+ if not hasattr(tokenizer, "sentinel_token"):
119
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
120
+ tokenizer.sentinel_token = SENTINEL_TOKEN
121
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
122
+
123
+
124
+ def preprocess_conversation(
125
+ conversation: Sequence[Dict[str, str]],
126
+ tokenizer: transformers.PreTrainedTokenizer,
127
+ no_system_prompt: bool = False,
128
+ retried: bool = False,
129
+ ) -> Dict[str, Any]:
130
+ inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
131
+ labels = torch.ones_like(inputs) * IGNORE_INDEX
132
+
133
+ # Generate the template by replacing the assistant's response with a sentinel.
134
+ _maybe_add_sentinel_token(tokenizer)
135
+ template = tokenize_conversation(
136
+ conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
137
+ )
138
+
139
+ # Remove sentinel tokens from the template.
140
+ mask = torch.ones_like(template, dtype=torch.bool)
141
+ for k in range(template.size(0) - 1):
142
+ if template[k] == tokenizer.sentinel_token_id:
143
+ mask[k : k + 2] = False
144
+ # NOTE(zhijianl): This is to handle the corner case where there is an empty token before the sentinel token.
145
+ if k > 0 and retried:
146
+ mask[k - 1] = False
147
+ template = template[mask]
148
+
149
+ # Match the tokenized conversation with the template (with no assistant's response).
150
+ # Every token that is not matched will be included in the label for training.
151
+ p = 0
152
+ for k in range(inputs.size(0)):
153
+ if p < template.size(0) and inputs[k] == template[p]:
154
+ p += 1
155
+ else:
156
+ labels[k] = inputs[k]
157
+
158
+ # Mask all tokens in the label if the template is not fully matched.
159
+ if p < template.size(0):
160
+ if not retried:
161
+ return preprocess_conversation(
162
+ conversation,
163
+ tokenizer,
164
+ no_system_prompt=no_system_prompt,
165
+ retried=True,
166
+ )
167
+ print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
168
+ labels[:] = IGNORE_INDEX
169
+
170
+ return {"input_ids": inputs, "labels": labels}
171
+
172
+
173
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
174
+ _maybe_add_sentinel_token(tokenizer)
175
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
176
+
177
+ stop_tokens = {tokenizer.eos_token}
178
+ for k in range(template.size(0) - 1):
179
+ if template[k] == tokenizer.sentinel_token_id:
180
+ stop_token = tokenizer.decode(template[k + 1])
181
+ stop_tokens.add(stop_token)
182
+ return list(stop_tokens)
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+ import os
18
+ import os.path as osp
19
+
20
+ from huggingface_hub import repo_exists, snapshot_download
21
+ from huggingface_hub.utils import HFValidationError, validate_repo_id
22
+ from transformers import AutoConfig, PretrainedConfig
23
+
24
+
25
+ def get_model_config(config):
26
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
27
+
28
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
29
+ root_path = config._name_or_path
30
+ else:
31
+ root_path = config.resume_path
32
+
33
+ # download from huggingface
34
+ if root_path is not None and not osp.exists(root_path):
35
+ try:
36
+ valid_hf_repo = repo_exists(root_path)
37
+ except HFValidationError as e:
38
+ valid_hf_repo = False
39
+ if valid_hf_repo:
40
+ root_path = snapshot_download(root_path)
41
+
42
+ return_list = []
43
+ for key in default_keys:
44
+ cfg = getattr(config, key, None)
45
+ if isinstance(cfg, dict):
46
+ try:
47
+ return_list.append(os.path.join(root_path, key[:-4]))
48
+ except:
49
+ raise ValueError(f"Cannot find resume path in config for {key}!")
50
+ elif isinstance(cfg, PretrainedConfig):
51
+ return_list.append(os.path.join(root_path, key[:-4]))
52
+ elif isinstance(cfg, str):
53
+ return_list.append(cfg)
54
+
55
+ return return_list
56
+
57
+
58
+ def get_model_config_fp8(config):
59
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
60
+
61
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
62
+ root_path = config._name_or_path
63
+ else:
64
+ root_path = config.resume_path
65
+
66
+ # download from huggingface
67
+ if root_path is not None and not osp.exists(root_path):
68
+ try:
69
+ valid_hf_repo = repo_exists(root_path)
70
+ except HFValidationError as e:
71
+ valid_hf_repo = False
72
+ if valid_hf_repo:
73
+ root_path = snapshot_download(root_path)
74
+
75
+ return_list = []
76
+ for key in default_keys:
77
+ cfg = getattr(config, key, None)
78
+ if isinstance(cfg, dict):
79
+ try:
80
+ return_list.append(os.path.join(root_path, key[:-4]))
81
+ except:
82
+ raise ValueError(f"Cannot find resume path in config for {key}!")
83
+ elif isinstance(cfg, PretrainedConfig):
84
+ return_list.append(os.path.join(root_path, key[:-4]))
85
+ elif isinstance(cfg, str):
86
+ return_list.append(cfg)
87
+
88
+ # fp8_llm
89
+ key = "fp8_llm_cfg"
90
+ directory_path = os.path.join(root_path, key[:-4])
91
+ assert os.path.isdir(directory_path) and os.listdir(
92
+ directory_path
93
+ ), "You need to first convert the model weights to FP8 explicitly."
94
+ return_list.append(directory_path)
95
+
96
+ return return_list
97
+
98
+
99
+ def get_model_config_fp8(config):
100
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
101
+
102
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
103
+ root_path = config._name_or_path
104
+ else:
105
+ root_path = config.resume_path
106
+
107
+ # download from huggingface
108
+ if root_path is not None and not osp.exists(root_path):
109
+ try:
110
+ valid_hf_repo = repo_exists(root_path)
111
+ except HFValidationError as e:
112
+ valid_hf_repo = False
113
+ if valid_hf_repo:
114
+ root_path = snapshot_download(root_path)
115
+
116
+ return_list = []
117
+ for key in default_keys:
118
+ cfg = getattr(config, key, None)
119
+ if isinstance(cfg, dict):
120
+ try:
121
+ return_list.append(os.path.join(root_path, key[:-4]))
122
+ except:
123
+ raise ValueError(f"Cannot find resume path in config for {key}!")
124
+ elif isinstance(cfg, PretrainedConfig):
125
+ return_list.append(os.path.join(root_path, key[:-4]))
126
+ elif isinstance(cfg, str):
127
+ return_list.append(cfg)
128
+
129
+ # fp8_llm
130
+ key = "fp8_llm_cfg"
131
+ directory_path = os.path.join(root_path, key[:-4])
132
+ assert os.path.isdir(directory_path) and os.listdir(
133
+ directory_path
134
+ ), "You need to first convert the model weights to FP8 explicitly."
135
+ return_list.append(directory_path)
136
+
137
+ return return_list
138
+
139
+
140
+ def is_mm_model(model_path):
141
+ """
142
+ Check if the model at the given path is a visual language model.
143
+
144
+ Args:
145
+ model_path (str): The path to the model.
146
+
147
+ Returns:
148
+ bool: True if the model is an MM model, False otherwise.
149
+ """
150
+ config = AutoConfig.from_pretrained(model_path)
151
+ architectures = config.architectures
152
+ for architecture in architectures:
153
+ if "llava" in architecture.lower():
154
+ return True
155
+ return False
156
+
157
+
158
+ def auto_upgrade(config):
159
+ cfg = AutoConfig.from_pretrained(config)
160
+ if "llava" in config and "llava" not in cfg.model_type:
161
+ assert cfg.model_type == "llama"
162
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
163
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
164
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
165
+ if confirm.lower() in ["y", "yes"]:
166
+ print("Upgrading checkpoint...")
167
+ assert len(cfg.architectures) == 1
168
+ setattr(cfg.__class__, "model_type", "llava")
169
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
170
+ cfg.save_pretrained(config)
171
+ print("Checkpoint upgraded.")
172
+ else:
173
+ print("Checkpoint upgrade aborted.")
174
+ exit(1)
vision_tower/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "runs/train/qwen25-7b-3x3-sft-20241115225329/model/vision_tower",
3
+ "architectures": [
4
+ "SiglipVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "hidden_act": "gelu_pytorch_tanh",
8
+ "hidden_size": 1152,
9
+ "image_size": 448,
10
+ "intermediate_size": 4304,
11
+ "layer_norm_eps": 1e-06,
12
+ "model_type": "siglip_vision_model",
13
+ "num_attention_heads": 16,
14
+ "num_channels": 3,
15
+ "num_hidden_layers": 27,
16
+ "num_image_tokens": 256,
17
+ "patch_size": 14,
18
+ "projection_dim": 2048,
19
+ "projector_hidden_act": "gelu_fast",
20
+ "torch_dtype": "bfloat16",
21
+ "transformers_version": "4.46.0",
22
+ "vision_use_head": false
23
+ }
vision_tower/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a46ef371610c7293e9d0b06b2e6f8f0644544c3307702adc64d3d4147a6acba
3
+ size 826707904
vision_tower/preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "SiglipImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "processor_class": "SiglipProcessor",
18
+ "resample": 3,
19
+ "rescale_factor": 0.00392156862745098,
20
+ "size": {
21
+ "height": 448,
22
+ "width": 448
23
+ }
24
+ }