darsoarafa commited on
Commit
6f0b9f8
·
verified ·
1 Parent(s): adde0a8

Upload 25 files

Browse files
dam/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from .describe_anything_model import *
18
+ from .model import *
dam/describe_anything_model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import numpy as np
20
+ from PIL import Image
21
+ from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
22
+ from .model.conversation import SeparatorStyle, conv_templates
23
+ from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token
24
+ from .model import get_model_name_from_path, load_pretrained_model
25
+ from transformers import TextIteratorStreamer
26
+ from threading import Thread
27
+
28
class DescribeAnythingModel(nn.Module):
    """Describe Anything Model (DAM): generates a free-form text description
    of the image region selected by a binary mask.

    Each image is turned into a 4-channel tensor (RGB + mask channel); when
    ``prompt_mode`` requests a second crop mode, a cropped 4-channel view is
    concatenated along the channel dimension as well.
    """

    def __init__(self, model_path, conv_mode, prompt_mode, **kwargs):
        """
        Args:
            model_path: checkpoint path (str), or a dict holding a pre-loaded
                ``model``, ``tokenizer`` and optionally ``model_name``.
            conv_mode: key into ``conv_templates`` selecting the chat template.
            prompt_mode: "<crop_mode>+<crop_mode2>" pair, e.g. "full+focal_crop";
                the first mode must be "full" (asserted at generation time).
            **kwargs: forwarded to ``load_pretrained_model`` when ``model_path``
                is a path string.
        """
        super().__init__()

        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode

        if isinstance(model_path, str):
            self.tokenizer, self.model, _, _ = load_pretrained_model(model_path, None, None, **kwargs)
            self.model_name = get_model_name_from_path(model_path)
        else:
            # model_path is actually a dict with model, tokenizer, and (optionally) model_name
            self.model = model_path["model"]
            self.tokenizer = model_path["tokenizer"]
            self.model_name = model_path.get("model_name", None)

        # Expose the vision tower's processor on the config so process_image()
        # can pick it up downstream.
        image_processor = self.model.vision_tower.image_processor
        self.model.config.image_processor = image_processor

    def get_prompt(self, qs):
        """Wrap the user query `qs` into the configured chat template.

        Returns:
            (prompt, conv): the rendered prompt string, plus the conversation
            object (needed later to recover the generation stop string).

        Raises:
            ValueError: if `qs` does not contain the `<image>` token.
        """
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError("no <image> tag found in input.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        """Return the tight bounding box (x0, y0, w, h) of nonzero mask pixels.

        Assumes at least one nonzero element; on an all-zero mask the
        min()/max() over an empty argwhere result would raise.
        """
        mask_coords = np.argwhere(mask_np)
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1

        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        """Crop `pil_img` (and `mask_np`) according to `crop_mode`.

        Modes:
            "full":         no crop; image returned unchanged.
            "crop":         tight crop to the mask bounding box.
            "context_crop": crop to 3x the box (one box-width/height of
                            context on each side), clamped to the image.
            "focal_crop":   like context_crop, but the box is first enlarged
                            to at least (min_box_w, min_box_h) around its
                            center to avoid resize artifacts on tiny regions.
            "crop_mask":    tight crop, with pixels outside the mask zeroed.

        Returns:
            (cropped_pil_img, info) where info["mask_np"] is the mask aligned
            with the returned image.

        Raises:
            ValueError: on an unsupported `crop_mode`.
        """
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode == "crop":
            # crop image and mask to the tight bounding box
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "context_crop":
            # crop with one extra box of context on each side, clamped to image
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "focal_crop":
            # enlarge tiny boxes before context-cropping
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]

            xc, yc = x0 + w/2, y0 + h/2
            # focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)

            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "crop_mask":
            # tight crop, then zero out pixels outside the mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            # Mask the image
            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        else:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        info = dict(mask_np=cropped_mask_np)
        return cropped_pil_img, info

    def get_description(self, image_pil, mask_pil, query, streaming=False, temperature=0.2, top_p=0.5, num_beams=1, max_new_tokens=512, **kwargs):
        """Describe the masked region(s).

        `image_pil`/`mask_pil` may each be a single PIL image or a matched
        list/tuple of images. Returns a string when ``streaming`` is False,
        otherwise a generator of text chunks.

        Extra **kwargs are passed to generation_kwargs:
        https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
        """
        prompt, conv = self.get_prompt(query)
        if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must be both list or tuple or not list or tuple."
            image_pils = [image_pil]
            mask_pils = [mask_pil]
        else:
            image_pils = image_pil
            mask_pils = mask_pil
        description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming, temperature=temperature, top_p=top_p, num_beams=num_beams, max_new_tokens=max_new_tokens, **kwargs)

        return description

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        """Build the model input tensor for one (image, mask) pair.

        Produces an RGB+mask 4-channel tensor for `crop_mode`; when
        `crop_mode2` is not None, a second 4-channel view is concatenated
        along the channel dimension.
        """
        # the pil has True/False (if the value is non-zero, then we treat it as True)
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        # Fixed: the preprocess lambda now uses its `pil_img` argument
        # (consistent with the crop_mode2 call below). With crop_mode "full"
        # this is the same object as image_pil, so behavior is unchanged.
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode))
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        # mask aligned with the (possibly cropped) image; with "full" it is
        # identical to the input mask, which the crop_mode2 call relies on.
        mask_np = image_info["mask_np"]
        mask_pil = Image.fromarray(mask_np * 255)

        masks_tensor = process_image(mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)

        # keep only one mask channel -> 4-channel (RGB + mask) input
        images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

        if crop_mode2 is not None:
            images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
            images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)

            mask_np2 = image_info2["mask_np"]
            mask_pil2 = Image.fromarray(mask_np2 * 255)

            masks_tensor2 = process_image(mask_pil2, self.model.config, None)
            masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)

            images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
        else:
            images_tensor2 = None

        return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor

    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False, temperature=0.2, top_p=0.5, num_beams=1, max_new_tokens=512, **kwargs):
        """Run generation for a fully rendered prompt.

        Returns a generator of text chunks when ``streaming`` is True,
        otherwise the single decoded output string.
        """
        if streaming:
            return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True, temperature=temperature, top_p=top_p, num_beams=num_beams, max_new_tokens=max_new_tokens, **kwargs)
        else:
            # If streaming is False, there will be only one output
            output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False, temperature=temperature, top_p=top_p, num_beams=num_beams, max_new_tokens=max_new_tokens, **kwargs)
            return next(output)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False, temperature=0.2, top_p=0.5, num_beams=1, max_new_tokens=512, **kwargs):
        """Generator backing both streaming and non-streaming generation.

        Yields incremental text chunks when ``streaming`` is True (generation
        runs on a background thread feeding a TextIteratorStreamer);
        otherwise yields a single fully decoded, stop-string-stripped output.
        """
        crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."

        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]

        # Fixed: use the model's device instead of hard-coded .cuda(), so the
        # class also works on CPU or a non-default CUDA device (consistent
        # with the .to(self.model.device, ...) calls above).
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None
        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=True if temperature > 0 else False,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **kwargs
        )

        if streaming:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                if stop_str in generated_text:
                    # truncate at the stop string and stop streaming
                    generated_text = generated_text[:generated_text.find(stop_str)]
                    break
                yield new_text

            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs
dam/model/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from .constants import *
18
+ from .conversation import *
19
+ from .mm_utils import *
20
+ from .model_utils import *
dam/model/configuration_llava.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from transformers import PretrainedConfig
18
+
19
+
20
class LlavaConfig(PretrainedConfig):
    """Configuration for the LLaVA-style multimodal model.

    Sub-module configurations (LLM, vision tower, multimodal projector, mask
    encoder, context provider) are stored as-is; all remaining arguments are
    plain scalar/flag attributes copied onto the config. Defaults of ``None``
    mean "unset" and are expected to be filled in by the loading code.
    """
    model_type = "llava"

    def __init__(
        self,
        llm_cfg=None,
        vision_tower_cfg=None,
        mm_projector_cfg=None,
        mask_encoder_cfg=None,
        context_provider_cfg=None,
        architectures=None,
        resume_path=None,
        hidden_size=None,
        mm_hidden_size=None,
        image_aspect_ratio=None,
        num_video_frames=None,
        mm_vision_select_layer=None,
        mm_vision_select_feature=None,
        mm_use_im_start_end=False,
        mm_use_im_patch_token=True,
        mm_projector_lr=None,
        vision_resolution=None,
        interpolate_mode=None,
        s2=None,
        s2_scales=None,
        s2_max_split_size=None,
        **kwargs
    ):
        # NOTE(review): **kwargs is accepted but not forwarded to
        # PretrainedConfig.__init__; confirm this is intentional (HF configs
        # usually pass kwargs through to the base class).
        super().__init__()
        self.architectures = architectures
        self.llm_cfg = llm_cfg
        self.vision_tower_cfg = vision_tower_cfg
        self.mm_projector_cfg = mm_projector_cfg
        self.mask_encoder_cfg = mask_encoder_cfg
        self.context_provider_cfg = context_provider_cfg
        self.resume_path = resume_path

        self.hidden_size = hidden_size
        self.mm_hidden_size = mm_hidden_size
        self.image_aspect_ratio = image_aspect_ratio
        self.num_video_frames = num_video_frames
        self.mm_vision_select_layer = mm_vision_select_layer
        self.mm_vision_select_feature = mm_vision_select_feature
        # Fixed: the original assigned mm_use_im_start_end twice in a row;
        # the duplicate assignment was removed (no behavior change).
        self.mm_use_im_start_end = mm_use_im_start_end
        self.mm_use_im_patch_token = mm_use_im_patch_token
        self.mm_projector_lr = mm_projector_lr
        self.vision_resolution = vision_resolution
        self.interpolate_mode = interpolate_mode
        self.s2 = s2
        self.s2_scales = s2_scales
        self.s2_max_split_size = s2_max_split_size
dam/model/constants.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
# Heartbeat timing (seconds) for the controller/worker serving setup.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory used for log output.
LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100  # label value skipped by the loss (standard HF/PyTorch ignore index — presumably; confirm against training code)
IMAGE_TOKEN_INDEX = -200  # sentinel token id marking where image embeddings are spliced in
MASK_TOKEN_INDEX = -300  # sentinel token id for mask embeddings
DEFAULT_IMAGE_TOKEN = "<image>"  # placeholder string users put in prompts
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
dam/model/conversation.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
+ import dataclasses
20
+ from enum import auto, Enum
21
+ from typing import List
22
+
23
+
24
class SeparatorStyle(Enum):
    """Message-separator conventions used by the conversation templates.

    Values are explicit but match what ``auto()`` would assign (1..7), so
    serialized values and ordering are unchanged.
    """

    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5
    MISTRAL = 6
    LLAMA_3 = 7
33
+
34
+
35
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    `messages` is a list of [role, message] pairs; a message may also be a
    tuple (text, image, image_process_mode) for multimodal turns. The exact
    string layout of the rendered prompt depends on `sep_style`.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int  # number of leading messages that are few-shot examples, skipped by get_images()/to_gradio_chatbot()
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # second separator; may be None (used by TWO/PLAIN/LLAMA_2/MISTRAL styles)
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Render the full prompt string according to `sep_style`.

        If the first message carries an image tuple, the `<image>` token is
        re-inserted (or wrapped in <Image></Image> for 'mmtag' versions)
        before rendering.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # First turn is multimodal: normalize its text and re-insert the image token.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            # system + sep, then "role: message" + sep per turn.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # Unanswered turn: leave "role:" open for generation.
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate sep (after user) and sep2 (after assistant).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            # Roles already embed their header markup; no ": " between role and text.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
            # [INST] ... [/INST] wrapping; system prompt folded into the first user turn.
            if self.sep_style == SeparatorStyle.LLAMA_2:
                wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            else:
                wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            if self.sep_style == SeparatorStyle.MISTRAL:
                ret += "<s>"

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        # user turn
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        # assistant turn
                        if self.sep_style == SeparatorStyle.LLAMA_2:
                            ret += " " + message + " " + self.sep2
                        else:
                            ret += message + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # Bare messages joined by alternating separators; roles omitted.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append one [role, message] turn to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images from user turns (past `offset`).

        Returns PIL images when `return_pil` is True, otherwise base64-encoded
        PNG strings. Images are padded/resized according to each turn's
        image_process_mode, then downscaled to at most ~800px on the long edge.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # user turns only
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        # Pad to a square with a neutral background color.
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Downscale so the short edge is <= 400 and long edge <= 800.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Convert the history (past `offset`) into Gradio chatbot pairs.

        Images in user turns are inlined as base64 <img> HTML tags.
        NOTE(review): the image is saved as JPEG but the data URI says
        image/png — browsers tolerate this, but it looks unintentional.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # user turn starts a new [user, assistant] pair
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # assistant turn fills the second slot of the current pair
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a deep-enough copy (messages list is re-built) of this conversation."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Serialize to a plain dict; image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
256
+
257
+
258
# ---------------------------------------------------------------------------
# Conversation template instances. Each template fixes the system prompt,
# role labels, separator style, and (for v0) a two-turn few-shot example
# (offset=2 marks those turns as examples).
# ---------------------------------------------------------------------------

conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# kentang-mit@: This conversation template is designed for SFT on VFLAN.
conv_vicuna_v1_nosys = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v1_nosys",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mistral = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="mistral",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MISTRAL,
    sep="",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain style: no system prompt, no role labels — used for pretraining-style data.
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# 'mmtag' versions wrap image content in <Image></Image> tags (see Conversation.get_prompt).
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)



conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

hermes_2 = Conversation(
    system='<|im_start|>system\nAnswer the questions.',
    roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
    sep_style=SeparatorStyle.MPT,
    sep='<|im_end|>',
    messages=(
    ),
    offset=0,
    version="hermes-2"
)


# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
llama_3_chat = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n",
           "<|start_header_id|>system<|end_header_id|>\n\n"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|end_of_text|>",
)


# Registry mapping conv_mode names to template instances.
default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_vicuna_v0,
    "hermes-2": hermes_2,
    "llama_3": llama_3_chat,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "vicuna_v1_nosys": conv_vicuna_v1_nosys,
    "llama_2": conv_llama_2,
    "mistral": conv_mistral,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
}

# if __name__ == "__main__":
#     print(default_conversation.get_prompt())
dam/model/language_model/builder.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/ and https://github.com/NVlabs/VILA/
18
+
19
+
20
+ import torch # noqa
21
+ import math
22
+ import warnings
23
+ import os.path as osp
24
+ from transformers import PretrainedConfig, PreTrainedModel
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ AutoModelForCausalLM,
28
+ AutoConfig,
29
+ )
30
+
31
+
32
def has_tokenizer(path):
    """Return True if *path* (a local directory or a HF Hub repo id) holds a
    complete tokenizer: both config JSONs plus a ``tokenizer.model`` or
    ``tokenizer.json`` weights file."""
    required_cfgs = ("special_tokens_map.json", "tokenizer_config.json")
    weight_files = ("tokenizer.model", "tokenizer.json")

    # Fast path: check the local filesystem first.
    local_cfgs_ok = all(osp.exists(osp.join(path, name)) for name in required_cfgs)
    local_weights_ok = any(osp.exists(osp.join(path, name)) for name in weight_files)
    if local_cfgs_ok and local_weights_ok:
        return True

    # Fall back to querying the Hugging Face Hub. Imported lazily so purely
    # local lookups never need the hub client.
    from huggingface_hub import HfApi, file_exists
    from huggingface_hub.utils import HFValidationError

    try:
        valid_hf_repo = HfApi().repo_exists(path)
    except HFValidationError:
        valid_hf_repo = False

    if valid_hf_repo:
        remote_cfgs_ok = all(file_exists(path, name) for name in required_cfgs)
        remote_weights_ok = any(file_exists(path, name) for name in weight_files)
        if remote_cfgs_ok and remote_weights_ok:
            return True
    return False
57
+
58
+
59
def context_length_extension(config):
    """Enable linear RoPE scaling when the requested context length exceeds
    the model's native window.

    Mutates *config* in place (and returns it): when both
    ``max_position_embeddings`` and ``model_max_length`` are set and the
    latter is larger, sets ``config.rope_scaling`` to a linear-scaling dict
    with a ceil'ed integer factor. Otherwise a no-op.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    # BUG FIX: guard model_max_length too — the original compared
    # `model_max_length > orig_ctx_len` and raised TypeError when
    # model_max_length was None.
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
67
+
68
+
69
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> PreTrainedModel:
    """Load the language model and its tokenizer for the VLM.

    Args:
        model_name_or_path: local path or HF Hub id of the LLM.
        config: top-level VLM config. ``config.model_dtype`` (a string such
            as "torch.float16") selects the load dtype; ``config.hidden_size``
            is updated from the loaded LLM before returning.
        attn_implementation: forwarded to the HF model config.
        model_max_length: when set, overrides the context length and triggers
            RoPE scaling via ``context_length_extension``.

    Returns:
        ``(llm, tokenizer)`` tuple.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # SECURITY FIX: resolve "torch.float16" -> torch.float16 via attribute
    # lookup instead of eval(); eval on a config-supplied string would
    # execute arbitrary code.
    dtype_name = str(config.model_dtype).split(".")[-1]
    torch_dtype = getattr(torch, dtype_name, None)
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"Unsupported model_dtype: {config.model_dtype!r}")

    llm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, config=llm_cfg, torch_dtype=torch_dtype, *args, **kwargs
    )

    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        # BUG FIX (message): this branch runs when the tokenizer is NOT in
        # the VLM root folder — the original warning said the opposite.
        warnings.warn("tokenizer not found in VLM root folder; falling back to the ./llm subfolder.")
        llm_path = osp.join(llm_path, "llm")

    # TODO(ligeng): use LLM class to judge for better compatibility.
    if "mpt" in model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
        )
    elif "yi" in model_name_or_path.lower():
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
            legacy=False,
        )

    # TODO(ligeng): is this necessary for llava?
    # Keep the top-level config in sync with the loaded LLM.
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
dam/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
16
+
17
+
18
+ from typing import List, Optional, Tuple, Union
19
+ import os
20
+ import torch
21
+
22
+ from transformers import (
23
+ AutoConfig,
24
+ AutoModel,
25
+ PretrainedConfig,
26
+ PreTrainedModel,
27
+ )
28
+
29
+ from transformers.modeling_outputs import CausalLMOutputWithPast
30
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
31
+ from ..configuration_llava import LlavaConfig
32
+ from ..mm_utils import get_model_name_from_path, tokenizer_image_token
33
+
34
class LlavaLlamaConfig(LlavaConfig):
    """LlavaConfig subclass registered under the ``llava_llama`` model type."""

    model_type = "llava_llama"
36
+
37
+ ## FIXME we will follow the convention to add a new class for CausalLM in the future
38
## FIXME we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    """Llama-backed LLaVA vision-language model.

    Component construction/loading lives in ``LlavaMetaModel``; multimodal
    input preparation lives in ``LlavaMetaForCausalLM``. The underlying
    causal LM is held in ``self.llm``.
    """

    config_class = LlavaLlamaConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True
    tokenizer_image_token = staticmethod(tokenizer_image_token)

    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
        super().__init__(config)
        self.dam_model = None  # lazily-created DescribeAnythingModel wrapper
        self.pretrained_model_name_or_path = None
        self.init_vlm(config=config, *args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = torch.float16,
        init_dam: bool = False,
        # conv_mode and prompt_mode are only used by `init_dam` in `from_pretrained` if `init_dam` is set to True
        conv_mode: str = "v1",
        prompt_mode: str = "full+focal_crop",
        **kwargs,
    ):
        """Load a pretrained VLM; optionally attach a DescribeAnythingModel."""
        # BUG FIX: only record the dtype when a config object is available.
        # The original unconditionally assigned `config.model_dtype` and
        # crashed with AttributeError when `config` was None (its default).
        if torch_dtype is not None and config is not None:
            config.model_dtype = str(torch_dtype)
        if hasattr(cls, "load_pretrained"):
            obj = cls.load_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        else:
            # BUG FIX: `super(LlavaLlamaModel)` builds an *unbound* super
            # object whose attribute lookup fails at runtime; the zero-arg
            # form binds `cls` correctly inside a classmethod.
            obj = super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        obj.pretrained_model_name_or_path = pretrained_model_name_or_path

        # `init_dam` is used to initialize a `DescribeAnythingModel` object in a `LlavaLlamaModel` in DAM.
        # If you initialize `DescribeAnythingModel` on your own outside, then you don't have to use this option.
        # This is very useful if you use `from_pretrained` with remote code execution and don't want to put
        # implementation for `DescribeAnythingModel` class in your codebase.
        if init_dam:
            obj.init_dam(conv_mode, prompt_mode)

        return obj

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Multimodal forward: splice image features into the embedding
        stream, repack sequences during training, then delegate to the LLM."""
        # Keep frozen sub-modules in eval mode even after HF calls .train().
        self.freezed_module_patch()
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )
        # Note (kentang-mit@): we have a unit test for this function.
        if self.training:
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            )
            new_input_ids = None
            past_key_values = None
        else:
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            new_input_ids = input_ids

        outputs = self.llm.forward(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            position_ids=new_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=new_inputs_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            seqlens_in_batch=sorted_seqlens_in_batch,
        )
        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.FloatTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generation_kwargs,
    ):
        """Generate text, embedding image features into the prompt when
        `images` is provided. Extra kwargs go to `self.llm.generate`."""
        if images is not None:
            (
                _,
                _,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, None, attention_mask, None, None, images
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds = inputs_embeds.to(self.dtype)

        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generation_kwargs
        )
        return outputs

    def init_dam(self, conv_mode: str = "v1", prompt_mode: str = "full+focal_crop"):
        """Create and attach a DescribeAnythingModel wrapper around this model.

        BUG FIX: both parameters now have defaults (matching the defaults in
        `from_pretrained`) so the `dam` property's no-argument call works;
        previously it raised TypeError.
        """
        from ...describe_anything_model import DescribeAnythingModel

        model_name = get_model_name_from_path(self.pretrained_model_name_or_path)
        self.dam_model = DescribeAnythingModel(
            model_path=dict(model=self, tokenizer=self.tokenizer, model_name=model_name),
            conv_mode=conv_mode,
            prompt_mode=prompt_mode,
        )
        return self.dam_model

    @property
    def dam(self):
        # Lazily build the wrapper with the default conv/prompt modes.
        if self.dam_model is None:
            self.init_dam()
        return self.dam_model
202
+
203
+ AutoConfig.register("llava_llama", LlavaLlamaConfig)
204
+ AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
dam/model/llava_arch.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import os.path as osp
17
+ import warnings
18
+ from abc import ABC
19
+
20
+ import torch
21
+ import logging
22
+
23
+ from transformers import (
24
+ AutoConfig,
25
+ PreTrainedModel,
26
+ )
27
+
28
+ from .constants import (
29
+ DEFAULT_IM_END_TOKEN,
30
+ DEFAULT_IM_START_TOKEN,
31
+ DEFAULT_IMAGE_PATCH_TOKEN,
32
+ IGNORE_INDEX,
33
+ IMAGE_TOKEN_INDEX,
34
+ )
35
+
36
+ from collections import OrderedDict
37
+ from .utils import get_model_config
38
+ from .language_model.builder import build_llm_and_tokenizer
39
+ from .multimodal_encoder.builder import build_vision_tower, build_context_provider
40
+ from .multimodal_projector.builder import build_mm_projector
41
+ from .configuration_llava import LlavaConfig
42
+
43
+ from transformers.modeling_utils import ContextManagers, no_init_weights
44
+
45
+ ## TODO decide whether should we use metaclass
46
## TODO decide whether should we use metaclass
class LlavaMetaModel(ABC):
    """Shared loading/saving/bookkeeping for the VLM components: ``llm``
    (+ tokenizer), ``vision_tower``, ``mm_projector`` and the optional
    ``context_provider``. Mixed into concrete models such as LlavaLlamaModel."""

    def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
        """Instantiate all sub-modules from *config*; no-op if already built."""
        # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
        if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
            # already initialized, skipped
            return

        model_dtype = getattr(config, "model_dtype", "torch.float16")
        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = model_dtype

        cfgs = get_model_config(config)
        # Only the first three are required. Others are optional.
        llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = cfgs
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        pass

    ## FIXME we will use this function to load model in the future
    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Construct a VLM shell without random weight init, then build and
        attach all components. Returns the constructed instance."""
        config = kwargs.pop("config", None)

        if config is None:
            if isinstance(model_path_or_config, str):
                config = AutoConfig.from_pretrained(model_path_or_config)
            elif isinstance(model_path_or_config, LlavaConfig):
                config = model_path_or_config
            else:
                raise NotImplementedError(f"wrong type, {type(model_path_or_config)} \
                {isinstance(model_path_or_config, LlavaConfig)}")

        model_dtype = getattr(config, "model_dtype", "torch.float16")
        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = model_dtype

        cfgs = get_model_config(config)
        # Only the first three are required. Others are optional.
        llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = cfgs
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")

        # Skip random weight init — real weights are loaded into each
        # component below.
        with ContextManagers([no_init_weights(_enable=True),]):
            vlm = cls(config, *args, **kwargs)

        if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "mm_projector"):
            if vlm.is_loaded:
                return vlm

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
        if mask_encoder_cfg is not None:
            raise NotImplementedError("Mask encoder is not supported.")
        vlm.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        # BUG FIX: the original called `self.post_config()` and assigned
        # `self.is_loaded` inside this classmethod, where `self` is
        # undefined (NameError). Operate on the constructed instance.
        vlm.post_config()
        vlm.is_loaded = True

        # FIXME(ligeng, yunhao): llm should never be none here.
        assert (
            vlm.llm is not None or vlm.vision_tower is not None or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    ## FIXME we will use this function to save the model in the future
    def save_pretrained(self, output_dir, state_dict=None):
        """Save each component (llm, vision_tower, mm_projector,
        context_provider) into its own subfolder, plus the top-level config."""
        if state_dict is None:
            # otherwise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            # Strip the "llm." prefix from the state-dict keys.
            llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
            self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        # "radio" towers are excluded from saving here — presumably handled
        # elsewhere; confirm before relying on this.
        if self.get_vision_tower() and "radio" not in self.get_vision_tower().__class__.__name__.lower():
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
            vision_tower_state_dict = OrderedDict(
                {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, 'auto_map'):
                delattr(self.config.vision_tower_cfg, 'auto_map')

        if self.get_mm_projector():
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
            mm_projector_state_dict = OrderedDict(
                {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        if self.get_context_provider():
            print(f"saving context_provider to {osp.join(output_dir, 'context_provider')}")
            self.context_provider.config._name_or_path = osp.join(output_dir, "context_provider")
            context_provider_state_dict = OrderedDict(
                {k.split("context_provider.")[-1]: v for k, v in state_dict.items() if "context_provider" in k}
            )
            self.context_provider.save_pretrained(
                os.path.join(output_dir, "context_provider"),
                state_dict=context_provider_state_dict,
            )
            self.config.context_provider_cfg = self.context_provider.config

        ## update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        """Return the language model (unwrapping a single-element list)."""
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the LLM's lm_head module, or None."""
        return getattr(self.get_llm(), "lm_head", None)

    def get_vision_tower(self):
        """Return the vision tower (unwrapping a single-element list)."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the multimodal projector (unwrapping a single-element list)."""
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def get_context_provider(self):
        """Return the optional context provider, or None."""
        return getattr(self, "context_provider", None)

    def post_config(self):
        """Mirror the training flag from the LLM and back-fill any missing
        per-component configs onto the top-level config."""
        self.training = self.get_llm().training
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config
        if getattr(self.config, "context_provider_cfg", None) is None and self.context_provider is not None:
            self.config.context_provider_cfg = self.context_provider.config

    def freezed_module_patch(self):
        '''
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
        '''
        if self.training:
            if self.get_llm() and not getattr(self.config, "tune_language_model", False):
                logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
            if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
                self.get_mm_projector().eval()
            if self.get_context_provider() and not getattr(self.config, "tune_context_provider", False):
                self.get_context_provider().eval()

    def encode_images(self, images):
        """Encode images with the vision tower and project into LLM space."""
        image_features = self.get_vision_tower()(images)
        image_features = self.get_mm_projector()(image_features)
        return image_features

    def encode_images_with_context(self, images):
        """Encode a batch mixing plain images (4 channels) and context images
        (8 channels: full image + crop concatenated), routing the latter
        through the context provider before projection."""
        context_provider = self.get_context_provider()
        # If the two 4-channel halves differ anywhere, it's a cimage
        # (image with context); identical halves mean a plain image.
        cimage_mask = torch.any((images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1)

        if context_provider.treat_image_as_cimage:
            # If the context provider treats the image as cimage, then all images are cimage.
            cimage_mask[:] = True

        if context_provider.context_image_as_queries:
            # Swap the crop image and full image since the model uses the full image as queries by default
            images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)

        # Process the first 4 channels for all images: for image it's the image, for cimage it's the full image
        vision_tower = self.get_vision_tower()
        # Encode context images (full images)
        image_features = vision_tower(images[:, :4, ...]).to(self.device)
        # Each cimage has 8 channels (full and crop concatenated)
        cimage_concatenated = images[cimage_mask]
        cimage_full_features = image_features[cimage_mask]
        if context_provider.context_provider_type == "cross_attn_end_to_all":
            cimage_features = self.context_provider(
                cimage_full_features=cimage_full_features,
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower
            ).to(self.device)
        elif context_provider.context_provider_type == "concat":
            # Full features of cimages are computed but not used.
            cimage_features = self.context_provider(
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower
            ).to(self.device)
        else:
            raise NotImplementedError(f"Context provider type {context_provider.context_provider_type} not implemented.")
        # Put cimage_features into image_features
        image_features[cimage_mask] = cimage_features

        # Project to the llm space
        image_features = self.get_mm_projector()(image_features)

        return image_features

    ## @yunhao: is there a better way to handle function call and attributes for llm?
    ## support beam search
    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        self.get_llm().resize_token_embeddings(embed_size)
+
307
+
308
+
309
+ class LlavaMetaForCausalLM(ABC):
310
+ """This class is originally implemented by the LLaVA team and
311
+ modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
312
+ to support multiple images and input packing."""
313
+
314
+ ## TODO move the forward function here if there is no need to override it
315
+ def prepare_inputs_labels_for_multimodal(
316
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images
317
+ ):
318
+ vision_tower = self.get_vision_tower()
319
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
320
+ if (
321
+ past_key_values is not None
322
+ and vision_tower is not None
323
+ and images is not None
324
+ and input_ids.shape[1] == 1
325
+ ):
326
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
327
+ attention_mask = torch.cat(
328
+ (
329
+ attention_mask,
330
+ torch.ones(
331
+ (
332
+ attention_mask.shape[0],
333
+ target_shape - attention_mask.shape[1],
334
+ ),
335
+ dtype=attention_mask.dtype,
336
+ device=attention_mask.device,
337
+ ),
338
+ ),
339
+ dim=1,
340
+ )
341
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
342
+ return (
343
+ input_ids,
344
+ position_ids,
345
+ attention_mask,
346
+ past_key_values,
347
+ None,
348
+ labels,
349
+ )
350
+ # handle different image dtypes for packing
351
+ if type(images) is list:
352
+ images = torch.cat(images, dim=0)
353
+ elif images.ndim == 5: # batch_size x seq_len x image_channels
354
+ images = images.flatten(0, 1)
355
+ if getattr(self, "context_provider", None):
356
+ image_features = self.encode_images_with_context(images)
357
+ else:
358
+ # Since we slice it with index below, turning it into a list splits things by the first index which does not result in data copy or degrade performance.
359
+ # Example dimension: [1, 196, 2560]
360
+ assert images.shape[1] <= 4, "images have more than 4 channels, but context provider is not included"
361
+ image_features = self.encode_images(images).to(self.device)
362
+ # Note (kentang-mit@): image start / end is not implemented here to support pretraining.
363
+ if getattr(self.config, "turn_mm_projector", False) and getattr(self.config, "mm_use_im_start_end", False):
364
+ raise NotImplementedError
365
+
366
+ # Let's just add dummy tensors if they do not exist,
367
+ # it is a headache to deal with None all the time.
368
+ # But it is not ideal, and if you have a better idea,
369
+ # please open an issue / submit a PR, thanks.
370
+ _labels = labels
371
+ _position_ids = position_ids
372
+ _attention_mask = attention_mask
373
+ if attention_mask is None:
374
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
375
+ else:
376
+ attention_mask = attention_mask.bool()
377
+ if position_ids is None:
378
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
379
+ if labels is None:
380
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
381
+
382
+ # remove the padding using attention_mask
383
+ input_ids_copy = input_ids.clone()
384
+ # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
385
+ input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
386
+ input_embeds = self.llm.model.embed_tokens(input_ids_copy)
387
+
388
+ input_ids = [
389
+ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
390
+ ]
391
+ input_embeds_1 = [
392
+ cur_input_embeds[cur_attention_mask]
393
+ for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask)
394
+ ]
395
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
396
+
397
+ new_input_embeds = []
398
+ new_labels = []
399
+ cur_image_idx = 0
400
+
401
+ # print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == IMAGE_TOKEN_INDEX).sum() for x in input_ids])
402
+
403
+ # kentang-mit@: If some part of the model is executed in the loop, the the loop length needs to be a constant.
404
+ for batch_idx, cur_input_ids in enumerate(input_ids):
405
+ cur_input_ids = input_ids[batch_idx]
406
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
407
+ if num_images == 0:
408
+ cur_image_features = image_features[0]
409
+ # cur_input_embeds_1 = self.get_llm().embed_tokens(cur_input_ids)
410
+ cur_input_embeds_1 = input_embeds_1[batch_idx]
411
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
412
+ new_input_embeds.append(cur_input_embeds)
413
+ new_labels.append(labels[batch_idx])
414
+ # kenang-mit@: we do not have placeholdr image for text-only data now.
415
+ # cur_image_idx += 1
416
+ continue
417
+
418
+ cur_input_embeds = input_embeds_1[batch_idx]
419
+ image_token_indices = (
420
+ [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
421
+ )
422
+ cur_input_ids_noim = []
423
+ cur_labels = labels[batch_idx]
424
+ cur_labels_noim = []
425
+ cur_input_embeds_no_im = []
426
+ for i in range(len(image_token_indices) - 1):
427
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
428
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
429
+ cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1 : image_token_indices[i + 1]])
430
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
431
+ # cur_input_embeds = self.get_llm().embed_tokens(torch.cat(cur_input_ids_noim))
432
+ # cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
433
+ cur_new_input_embeds = []
434
+ cur_new_labels = []
435
+ for i in range(num_images + 1):
436
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
437
+ cur_new_labels.append(cur_labels_noim[i])
438
+ if i < num_images:
439
+ cur_image_features = image_features[cur_image_idx]
440
+ cur_image_idx += 1
441
+ cur_new_input_embeds.append(cur_image_features)
442
+ cur_new_labels.append(
443
+ torch.full(
444
+ (cur_image_features.shape[0],),
445
+ IGNORE_INDEX,
446
+ device=cur_labels.device,
447
+ dtype=cur_labels.dtype,
448
+ )
449
+ )
450
+
451
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
452
+ cur_new_labels = torch.cat(cur_new_labels)
453
+
454
+ new_input_embeds.append(cur_new_input_embeds)
455
+ new_labels.append(cur_new_labels)
456
+
457
+ # Truncate sequences to max length as image embeddings can make the sequence longer
458
+ tokenizer_model_max_length = getattr(self.llm.config, "tokenizer_model_max_length", None)
459
+ if tokenizer_model_max_length is not None:
460
+ if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
461
+ warnings.warn("Inputs truncated!")
462
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
463
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
464
+ # Combine them
465
+ max_len = max(x.shape[0] for x in new_input_embeds)
466
+ batch_size = len(new_input_embeds)
467
+
468
+ new_input_embeds_padded = []
469
+ new_labels_padded = torch.full(
470
+ (batch_size, max_len),
471
+ IGNORE_INDEX,
472
+ dtype=new_labels[0].dtype,
473
+ device=new_labels[0].device,
474
+ )
475
+ attention_mask = torch.zeros(
476
+ (batch_size, max_len),
477
+ dtype=attention_mask.dtype,
478
+ device=attention_mask.device,
479
+ )
480
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
481
+
482
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
483
+ cur_len = cur_new_embed.shape[0]
484
+ if getattr(self.llm.config, "tokenizer_padding_side", "right") == "left":
485
+ new_input_embeds_padded.append(
486
+ torch.cat(
487
+ (
488
+ torch.zeros(
489
+ (max_len - cur_len, cur_new_embed.shape[1]),
490
+ dtype=cur_new_embed.dtype,
491
+ device=cur_new_embed.device,
492
+ ),
493
+ cur_new_embed,
494
+ ),
495
+ dim=0,
496
+ )
497
+ )
498
+ if cur_len > 0:
499
+ new_labels_padded[i, -cur_len:] = cur_new_labels
500
+ attention_mask[i, -cur_len:] = True
501
+ position_ids[i, -cur_len:] = torch.arange(
502
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
503
+ )
504
+ else:
505
+ new_input_embeds_padded.append(
506
+ torch.cat(
507
+ (
508
+ cur_new_embed,
509
+ torch.zeros(
510
+ (max_len - cur_len, cur_new_embed.shape[1]),
511
+ dtype=cur_new_embed.dtype,
512
+ device=cur_new_embed.device,
513
+ ),
514
+ ),
515
+ dim=0,
516
+ )
517
+ )
518
+ if cur_len > 0:
519
+ new_labels_padded[i, :cur_len] = cur_new_labels
520
+ attention_mask[i, :cur_len] = True
521
+ position_ids[i, :cur_len] = torch.arange(
522
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
523
+ )
524
+
525
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
526
+
527
+ if _labels is None:
528
+ new_labels = None
529
+ else:
530
+ new_labels = new_labels_padded
531
+
532
+ if _attention_mask is None:
533
+ attention_mask = None
534
+ else:
535
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
536
+
537
+ if _position_ids is None:
538
+ position_ids = None
539
+
540
+ return (
541
+ None,
542
+ position_ids,
543
+ attention_mask,
544
+ past_key_values,
545
+ new_input_embeds,
546
+ new_labels,
547
+ )
548
+
549
    def repack_multimodal_data(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    ):
        """Repack padded per-sample sequences into densely packed rows.

        Sorts samples by their unpadded length (descending), then greedily
        concatenates them into rows of at most ``inputs_embeds.shape[1]``
        tokens, stripping the padding between samples.  This reduces wasted
        compute on pad tokens; it requires a transformers build that supports
        packed sequences (see the original note below).

        Returns a tuple mirroring ``prepare_inputs_labels_for_multimodal``:
        ``(None, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)`` plus ``sorted_seqlens_in_batch`` for the
        attention kernel.
        """
        # kentang-mit@: reorder and repack (reduce computation overhead)
        # requires transformers replacement.
        new_inputs_embeds = []
        new_position_ids = []
        new_labels = []
        # Unpadded length of each sample in the batch.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        sorted_seqlens_in_batch, sorted_idx = torch.sort(seqlens_in_batch, descending=True)
        # A packed row may hold at most the original padded sequence length.
        max_seqlen = inputs_embeds.shape[1]

        cur_inputs_embeds = []
        cur_position_ids = []
        cur_labels = []
        cur_batch_len = 0
        for i in range(len(sorted_seqlens_in_batch)):
            cur_seqlen = sorted_seqlens_in_batch[i].item()
            if cur_seqlen + cur_batch_len <= max_seqlen:
                # Sample still fits into the current packed row.
                cur_batch_len += cur_seqlen
                # each item: num_tokens x num_channels
                # remove padding on-the-fly
                cur_inputs_embeds.append(inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]])
                # each item: num_tokens; position ids restart at 0 per sample
                cur_position_ids.append(
                    torch.arange(
                        cur_inputs_embeds[-1].shape[0],
                        device=cur_inputs_embeds[-1].device,
                    )
                )
                # each item: num_tokens
                # remove padding on-the-fly
                cur_labels.append(labels[sorted_idx[i]][attention_mask[sorted_idx[i]]])
            else:
                # The current row is full: flush it and start a new row
                # seeded with this sample.
                new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
                new_position_ids.append(torch.cat(cur_position_ids, 0))
                new_labels.append(torch.cat(cur_labels, 0))
                # The current batch is too long. We will start a new batch.
                cur_batch_len = cur_seqlen
                cur_inputs_embeds = [inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]]]
                cur_position_ids = [
                    torch.arange(
                        cur_inputs_embeds[-1].shape[0],
                        device=cur_inputs_embeds[-1].device,
                    )
                ]
                cur_labels = [labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]]

        # Flush the final, partially filled row.
        if len(cur_inputs_embeds):
            new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
            new_position_ids.append(torch.cat(cur_position_ids, 0))
            new_labels.append(torch.cat(cur_labels, 0))

        # NOTE(review): pad_token_id is used as the numeric fill for the
        # *embedding* pad positions; those positions are masked out below,
        # so the value itself should be inert — confirm with the attention
        # implementation in use.
        new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
        )

        new_position_ids = torch.nn.utils.rnn.pad_sequence(new_position_ids, batch_first=True, padding_value=-1)

        new_labels = torch.nn.utils.rnn.pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)
        ## yunhao: it's currently a workaround to avoid errors for seq_len < 100
        # Position value -1 marks the padding inserted just above.
        new_attention_mask = new_position_ids.ne(-1)
        # sanity check: repacking must neither drop nor invent real tokens
        assert new_attention_mask.sum() == attention_mask.sum()

        # input_ids are fully consumed as embeddings, hence None first slot.
        return (
            None,
            new_position_ids,
            new_attention_mask,
            past_key_values,
            new_inputs_embeds,
            new_labels,
            sorted_seqlens_in_batch,
        )
636
+
637
    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Register image special tokens and resize the embedding tables.

        Depending on ``model_args``:
        - ``mm_use_im_patch_token``: add the image-patch placeholder token.
        - ``mm_use_im_start_end``: add the <im_start>/<im_end> tokens,
          warm-start their embeddings with the mean of existing rows, and
          optionally copy pretrained rows from ``pretrain_mm_mlp_adapter``.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialize the new rows to the mean of the existing
                # embeddings — a standard warm start for new special tokens.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
            ## TODO yunhao: handle cases for <im_st> <im_end>
            if model_args.pretrain_mm_mlp_adapter:
                # Copy the start/end token embeddings saved alongside the
                # pretrained projector, when shapes allow it.
                # NOTE(review): this branch reads `input_embeddings`, which is
                # only bound when num_new_tokens > 0 above — confirm callers
                # never reach here with num_new_tokens == 0.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            # Patch-token-only setup: keep the embedding tables frozen while
            # the projector is being trained.
            if model_args.mm_projector:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
dam/model/mm_utils.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from PIL import Image
18
+ from io import BytesIO
19
+ import base64
20
+ import numpy as np
21
+ import os
22
+
23
+ import torch
24
+ from transformers import StoppingCriteria
25
+ from .constants import IMAGE_TOKEN_INDEX
26
+
27
+ import tempfile
28
+
29
def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
    """Sample ``num_frames`` evenly spaced RGB frames from an OpenCV capture.

    Args:
        vidcap: an opened ``cv2.VideoCapture``.
        num_frames (int): number of frames to return.
        fps (float, optional): cached frames-per-second metadata.
        frame_count (int, optional): cached total frame count.  When either
            cached value is ``None``, both are recomputed from the capture.

    Returns:
        list: exactly ``num_frames`` PIL Images.  Videos shorter than
        ``num_frames`` are left-padded with black frames; unreadable videos
        yield 720x720 black placeholders.

    Raises:
        ValueError: if the capture reports enough frames but decoding stops
            before ``num_frames`` frames were collected.
    """
    import cv2

    if fps is None or frame_count is None:
        # If either cached value is missing, recompute both.
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps == 0 or frame_count == 0:
            print("Video file not found. return empty images.")
            return [
                Image.new("RGB", (720, 720)),
            ] * num_frames

    frame_interval = frame_count // num_frames
    if frame_interval == 0 and frame_count <= 1:
        print("frame_interval is equal to 0. return empty image.")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames

    images = []
    count = 0
    success = True
    # Target indices spread evenly over the video; the final frame is often
    # undecodable, hence frame_count - 2 as the upper bound.
    frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)

    while success:
        if frame_count >= num_frames:
            success, frame = vidcap.read()
            # BUGFIX: only convert when the read succeeded — the original
            # passed None to cv2.cvtColor on a premature end of stream.
            if success and count in frame_indices:
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                images.append(Image.fromarray(img))
                if len(images) >= num_frames:
                    return images
            count += 1
        else:
            # Video shorter than num_frames: keep every decodable frame,
            # then left-pad with black frames.
            success, frame = vidcap.read()
            if success:
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                images.append(Image.fromarray(img))
                count += 1
            elif count >= 1:
                # BUGFIX: compute the padding count *before* extending the
                # list; the original printed it afterwards, always showing 0.
                num_padding = num_frames - len(images)
                width, height = images[-1].size
                images = [Image.new("RGB", (width, height))] * num_padding + images
                print("padding frames:", num_padding)
                return images
            else:
                break
    raise ValueError(
        "Did not find enough frames in the video. return empty image.")
85
+
86
+
87
def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
    """
    Extract frames from a video using OpenCV.

    Args:
        vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
        frames (int): Number of frames to extract from the video.
        fps (float, optional): Cached frames-per-second metadata.
        frame_count (int, optional): Cached total frame count.

    Returns:
        list: List of PIL Images extracted from the video.

    Raises:
        NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
    """
    import cv2

    if isinstance(vpath_or_bytesio, str):
        vidcap = cv2.VideoCapture(vpath_or_bytesio)
        return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
    elif isinstance(vpath_or_bytesio, (BytesIO,)):
        # assuming mp4; OpenCV needs a real path, so spill the bytes to a
        # named temporary file for the duration of the decode.
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(vpath_or_bytesio.read())
            # BUGFIX: flush buffered bytes to disk before cv2 opens the file
            # by name; otherwise the tail of the video can be missing.
            temp_video.flush()
            vidcap = cv2.VideoCapture(temp_video.name)
            return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
    else:
        raise NotImplementedError(type(vpath_or_bytesio))
115
+
116
+
117
def load_image_from_base64(image):
    """Decode a base64-encoded image string/bytes into a PIL Image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
119
+
120
+
121
def expand2square(pil_img, background_color):
    """
    Expand the given PIL image to a square shape by adding padding.

    Parameters:
    - pil_img: The PIL image to be expanded.
    - background_color: The color of the padding to be added.

    Returns:
    - The expanded PIL image.

    If the image is already square, it is returned as is.  Otherwise the
    image is centered on a square canvas whose side equals the longer edge,
    filled with `background_color`.
    """
    width, height = pil_img.size
    if pil_img.mode == 'L':
        # Grayscale images take a single-channel fill value.
        background_color = background_color[0]
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Centering offsets: one of the two is always zero.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas
149
+
150
+
151
def process_image(image_file, data_args, image_folder, pil_preprocess_fn=None):
    """Load an image and run it through the configured vision preprocessor.

    Args:
        image_file: a path string (joined with `image_folder` when given) or
            an already-decoded PIL image.
        data_args: object carrying `image_processor` and `image_aspect_ratio`.
        image_folder: optional root directory for relative path inputs.
        pil_preprocess_fn: optional PIL-level hook applied before the
            processor; may return `(image, info)` to pass metadata through.

    Returns:
        The preprocessed pixel tensor, or `(tensor, info)` when the hook
        produced extra info.
    """
    processor = data_args.image_processor
    if isinstance(image_file, str):
        if image_folder is not None:
            image = Image.open(os.path.join(
                image_folder, image_file)).convert("RGB")
        else:
            image = Image.open(image_file).convert("RGB")
    else:
        # Already a PIL image (e.g. decoded from a bytearray upstream).
        image = image_file.convert("RGB")

    info = None

    if pil_preprocess_fn is not None:
        image = pil_preprocess_fn(image)
        if isinstance(image, tuple):
            image, info = image

    if data_args.image_aspect_ratio == "resize":
        if hasattr(data_args.image_processor, "crop_size"):
            # CLIP vision tower
            crop_size = data_args.image_processor.crop_size
        else:
            # SIGLIP vision tower
            assert hasattr(data_args.image_processor, "size")
            crop_size = data_args.image_processor.size
        # NOTE(review): PIL's Image.resize expects (width, height) but this
        # passes (height, width).  Harmless for the usual square crops —
        # confirm before using non-square sizes.
        image = image.resize((crop_size["height"], crop_size["width"]))
    if data_args.image_aspect_ratio == "pad":
        # Pad to square with the processor's mean color, then preprocess.
        # Reuses the module-level expand2square; the original carried an
        # identical nested copy of it here.
        image = expand2square(image, tuple(int(x * 255)
                                           for x in processor.image_mean))
        image = processor.preprocess(image, return_tensors="pt")[
            "pixel_values"][0]
    else:
        # Using default behavior of the vision encoder
        # For CLIP, default is central crop
        # For Radio, default is central crop
        # For Siglip, default is resize
        # For InternVIT, default is resize
        image = processor.preprocess(image, return_tensors="pt")[
            "pixel_values"][0]
    if info is not None:
        return image, info
    return image
211
+
212
+
213
def process_images(images, image_processor, model_cfg):
    """Preprocess a list of images; stack into one tensor when shapes agree.

    Side effect: stores `image_processor` on `model_cfg` so that
    `process_image` can find it.
    """
    model_cfg.image_processor = image_processor
    processed = [process_image(img, model_cfg, None) for img in images]

    shapes_match = all(t.shape == processed[0].shape for t in processed)
    if shapes_match:
        processed = torch.stack(processed, dim=0)
    return processed
221
+
222
+
223
+ # Note that newer VILA codebase adds an lstrip option that defaults to False, and the functionality is the same by default
224
# Note that newer VILA codebase adds an lstrip option that defaults to False, and the functionality is the same by default
def tokenizer_image_token(
    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
    """Tokenize `prompt`, replacing each "<image>" with `image_token_index`.

    A single leading BOS token is preserved; duplicated BOS tokens on the
    later chunks are stripped.  Returns a list of ids, or a long tensor when
    `return_tensors == "pt"`.
    """
    chunks = [tokenizer(part).input_ids for part in prompt.split("<image>")]

    def interleave(pieces, sep):
        # [a, b, c] -> [a, sep, b, sep, c]
        out = []
        for piece in pieces:
            out.append(piece)
            out.append(sep)
        return out[:-1]

    input_ids = []
    offset = 0
    if chunks and chunks[0] and chunks[0][0] == tokenizer.bos_token_id:
        # Keep one BOS token up front; skip it in every chunk below.
        offset = 1
        input_ids.append(chunks[0][0])

    for piece in interleave(chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(piece[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
251
+
252
+
253
def is_gemma_tokenizer(tokenizer):
    """Heuristically detect a Gemma tokenizer by its class name."""
    class_name = type(tokenizer).__name__
    return "gemma" in class_name.lower()
255
+
256
+
257
def get_model_name_from_path(model_path):
    """Derive a human-readable model name from a checkpoint path.

    Falls back to "describe_anything_model" for empty input.  For a
    "checkpoint-*" leaf directory, the parent directory name is prefixed so
    checkpoints of different runs stay distinguishable.
    """
    if not model_path:
        return "describe_anything_model"
    parts = model_path.strip("/").split("/")
    leaf = parts[-1]
    if leaf.startswith("checkpoint-"):
        return f"{parts[-2]}_{leaf}"
    return leaf
266
+
267
+
268
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword string appears in the generated tail.

    Matches first on raw token-id suffixes, then falls back to a substring
    check on the decoded text of the most recently generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so suffix comparison can match.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Prompt length: only tokens generated after this are inspected.
        self.start_len = input_ids.shape[1]

    def call_for_batch(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Decode no more tokens than could possibly contain a keyword.
        offset = min(output_ids.shape[1] -
                     self.start_len, self.max_keyword_len)
        # Keep cached keyword ids on the same device as the outputs.
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        # Fast path: exact token-id suffix match.
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        # Fallback: substring match on the decoded tail (covers keywords
        # that tokenize differently in context).
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Stop only when every sequence in the batch has hit a keyword.
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(
                output_ids[i].unsqueeze(0), scores))
        return all(outputs)
312
+ return all(outputs)
dam/model/model_utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This file is modified from https://github.com/haotian-liu/LLaVA/ and https://github.com/NVlabs/VILA/
16
+
17
+ import torch
18
+ from transformers import (
19
+ AutoConfig,
20
+ BitsAndBytesConfig,
21
+ PretrainedConfig,
22
+ )
23
+
24
+ from .language_model.llava_llama import LlavaLlamaModel
25
+ # TODO: we may move LlavaConfig to configuration_llava.py
26
+ # from model.configuration_llava import LlavaConfig
27
+
28
+
29
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    # Replace the weight-init hooks with no-ops; checkpoint loading will
    # overwrite the parameters anyway, so the random init is wasted work.
    torch.nn.Linear.reset_parameters = lambda self: None
    torch.nn.LayerNorm.reset_parameters = lambda self: None
+
37
+
38
def load_pretrained_model(
    model_path,
    model_name,
    model_base=None,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    **kwargs,
):
    """Load a DAM (LLaVA-style) checkpoint for evaluation.

    Args:
        model_path: local path or hub id of the checkpoint.
        model_name: accepted for interface compatibility; not used here.
        model_base: accepted for interface compatibility; not used here.
        load_8bit / load_4bit: enable bitsandbytes quantized loading
            (4-bit uses NF4 with fp16 compute and double quantization).
        device_map: HF device map; overridden to pin everything on `device`
            when `device` is not "cuda".
        device: target device for the vision tower / projector / provider.
        **kwargs: forwarded to the model constructor.

    Returns:
        (tokenizer, model, image_processor, context_len)
    """
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        # Pin every module to the requested device instead of auto-sharding.
        kwargs["device_map"] = {"": device}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        # Full-precision-weights path: load in fp16.
        kwargs["torch_dtype"] = torch.float16

    config = AutoConfig.from_pretrained(model_path)
    config.resume_path = model_path
    # Normalizes legacy config fields and records the model dtype.
    prepare_config_for_eval(config, kwargs)

    model = LlavaLlamaModel(
        config=config,
        low_cpu_mem_usage=True,
        **kwargs
    )
    tokenizer = model.tokenizer

    model.eval()

    model.resize_token_embeddings(len(tokenizer))
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    mm_projector = model.get_mm_projector()
    mm_projector.to(device=device, dtype=torch.float16)
    context_provider = model.get_context_provider()
    if context_provider is not None:
        # The context provider is optional (focal-crop conditioning).
        context_provider.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # NOTE(review): checks llm.config for the attribute but reads it from
    # model.config — confirm both carry max_sequence_length when present.
    if hasattr(model.llm.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
107
+
108
+
109
def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
    """Resolve the name-or-path stored under ``config.<model_name><suffix>``.

    The sub-config may be a plain string (already a name/path) or a
    serialized dict, in which case the first declared architecture is used.
    Raises ValueError for anything else (including a missing attribute).
    """
    target_model = f"{model_name}{suffix}"
    target_cfg = getattr(config, target_model, None)

    if isinstance(target_cfg, str):
        return target_cfg
    if isinstance(target_cfg, dict):
        return target_cfg["architectures"][0]
    raise ValueError(f"Invalid {target_model} configuration!")
119
+
120
+
121
def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
    """Normalize `config` in place before instantiating the model for eval.

    - Back-fills `vision_tower_cfg` from the deprecated `mm_vision_tower`.
    - Records the load dtype on `config.model_dtype` (as a string).
    - Forces `device_map="cuda"` for SigLIP towers, which do not support
      `device_map="auto"`.

    Raises:
        ValueError: when no vision tower can be found in the config.
    """
    try:
        # compatible with deprecated config convention
        if getattr(config, "vision_tower_cfg", None) is None:
            config.vision_tower_cfg = config.mm_vision_tower
    except AttributeError:
        raise ValueError(
            f"Invalid configuration! Cannot find vision_tower in config:\n{config}")

    # BUGFIX: quantized loading (load_8bit / load_4bit) never puts
    # "torch_dtype" into kwargs, so a plain pop raised KeyError.  Fall back
    # to float16, matching the quantized compute dtype.
    config.model_dtype = str(kwargs.pop("torch_dtype", torch.float16))
    # siglip does not support device_map = "auto"
    vision_tower_name = parse_model_name_or_path(config, "vision_tower")
    if "siglip" in vision_tower_name.lower():
        kwargs["device_map"] = "cuda"
dam/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+ import torch # noqa
19
+ import os
20
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
21
+ from .siglip_encoder import SiglipVisionTower
22
+ from .context_provider import ContextProvider, ContextProviderConfig
23
+
24
def build_vision_tower(
    model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Instantiate the vision tower named by `model_name_or_path`.

    Returns None when no tower is configured.  When resuming from a local
    checkpoint (non-radio), the architecture is read from the saved config.
    Side effect: stores the tower's hidden size on `config.mm_hidden_size`.
    """
    # Skip vision tower instantiation entirely when unset.
    if model_name_or_path is None:
        return None

    vision_tower_arch = None
    if config.resume_path and "radio" not in model_name_or_path:
        # Resuming: the path must exist locally; identify the architecture
        # from the checkpoint's own config.
        assert os.path.exists(
            model_name_or_path
        ), f"Resume vision tower path {model_name_or_path} does not exist!"
        resumed_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        vision_tower_arch = resumed_cfg.architectures[0].lower()
    vision_tower_name = (
        model_name_or_path if vision_tower_arch is None else vision_tower_arch
    )

    if "siglip" not in vision_tower_name:
        raise ValueError(f"Unknown vision tower: {model_name_or_path}")
    vision_tower = SiglipVisionTower(model_name_or_path, config)

    # Propagate the tower's hidden size so the projector can be sized.
    config.mm_hidden_size = vision_tower.config.hidden_size
    return vision_tower
49
+
50
def build_context_provider(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Build the context provider, resuming from a checkpoint when possible.

    Returns None when no context provider is configured.
    NOTE(review): `eval(config.model_dtype)` assumes a trusted config value
    such as "torch.float16" — never feed untrusted configs through here.
    """
    if model_type_or_path is None:
        return None

    if config.resume_path:
        # Load pretrained provider weights saved with the resumed model.
        assert os.path.exists(
            model_type_or_path
        ), f"Resume context provider path {model_type_or_path} does not exist!"
        return ContextProvider.from_pretrained(
            model_type_or_path, config, torch_dtype=eval(config.model_dtype)
        )

    # Build a fresh provider from scratch.
    provider_cfg = ContextProviderConfig(model_type_or_path)
    provider = ContextProvider(provider_cfg, config).to(eval(config.model_dtype))
    return provider
dam/model/multimodal_encoder/context_provider.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch.nn as nn
18
+ import torch
19
+ import torch.nn.functional as F
20
+ # import deepspeed
21
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
22
+
23
class ContextProviderConfig(PretrainedConfig):
    """Configuration for the ContextProvider cross-attention module.

    The encoder geometry mirrors the paired vision tower so that
    cross-attention layers can be interleaved with the tower's own layers.
    """

    model_type = "context_provider"

    def __init__(
        self,
        context_provider_type: str=None,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_mask_channels=0,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        zero_init_output=True,
        residual_dropout=0.0,
        context_image_as_queries=False,
        context_provider_layer_indices=None,
        masked_cross_attn=False,
        crop_position_single_embedding=False,
        trainable_crop_position_embedding=True,
        crop_embedding_mode="add",
        treat_image_as_cimage=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Which provider implementation to build (e.g. "cross_attn_end_to_all").
        self.context_provider_type = context_provider_type

        # Encoder geometry — `num_hidden_layers` should match the vision tower.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.num_mask_channels = num_mask_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout

        # Behaviour of the injected cross-attention layers.
        self.zero_init_output = zero_init_output
        self.residual_dropout = residual_dropout
        self.context_image_as_queries = context_image_as_queries
        # None means "inject at every layer"; otherwise a collection of indices.
        self.context_provider_layer_indices = context_provider_layer_indices
        self.masked_cross_attn = masked_cross_attn

        # Crop positional-embedding options.
        # If enabled, crop_position_embedding (delta to full pos) is updated during training.
        self.trainable_crop_position_embedding = trainable_crop_position_embedding
        # If enabled, a single embedding is shared across all crop positions.
        self.crop_position_single_embedding = crop_position_single_embedding
        # "add": delta on top of the original positional embedding;
        # "replace": do not add the original positional embedding.
        self.crop_embedding_mode = crop_embedding_mode

        # If True, a plain input image is treated as a cimage (mask of all 1s).
        self.treat_image_as_cimage = treat_image_as_cimage
85
+
86
+
87
+ # Context Provider
88
+ from transformers.activations import ACT2FN
89
+ from typing import Optional, Tuple
90
+
91
class ContextProviderCrossAttention(nn.Module):
    """Multi-headed cross-attention from 'Attention Is All You Need' paper.

    Queries come from the detail stream; keys/values come from the context
    stream (``encoder_hidden_states``).
    """

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel.

        Note: the return annotation previously promised an optional third
        element; this method always returns ``(attn_output, attn_weights)``.
        ``output_attentions`` is accepted for API compatibility but attention
        weights are returned unconditionally.
        """
        batch_size, q_len, _ = hidden_states.size()
        batch_size, kv_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        # (B, T, C) -> (B, num_heads, T, head_dim)
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask: masked positions carry -inf and vanish after softmax.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32 for numerical stability of the softmax
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (B, num_heads, T, head_dim) -> (B, T, C)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
182
+
183
class ContextProviderMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> Linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation chosen by name from the shared transformers registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to the intermediate width, apply the nonlinearity, then
        # project back down to the embedding width.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
196
+
197
+
198
def get_token_mask_bias(mask, patch_size):
    """Convert a pixel mask into a per-token additive attention bias.

    *mask* is (B, H, W) with values in {0, 1}. Each non-overlapping
    ``patch_size`` x ``patch_size`` patch becomes one token; tokens whose
    patch contains no mask pixels get -inf, all others get 0.

    Returns a flattened (B, num_tokens) tensor.
    """
    with torch.no_grad():
        # Sum mask pixels inside each patch with a ones-kernel convolution.
        kernel = torch.ones(
            (1, 1, patch_size, patch_size),
            device=mask.device, dtype=mask.dtype
        )
        patch_coverage = F.conv2d(
            input=mask[:, None],
            weight=kernel,
            bias=None,
            stride=(patch_size, patch_size),
            padding="valid"
        )

        # 0 where the patch overlaps the mask, -inf where it does not.
        token_bias = torch.where(
            patch_coverage < 1e-5,
            torch.full_like(patch_coverage, float("-inf")),
            torch.zeros_like(patch_coverage),
        )

        # (B, 1, Ht, Wt) -> (B, Ht * Wt)
        return token_bias.flatten(1)
220
+
221
def attn_mask_from_cimage_concatenated(cimage_concatenated, patch_size):
    """Build a (B, 1, Q, KV) additive cross-attention bias from the mask channel.

    Channel 3 of *cimage_concatenated* holds the region mask normalized to
    [-1, 1]; tokens whose patch lies entirely outside the mask get -inf.
    """
    # Recover the [0, 1] mask from the normalized 4th channel.
    mask_normalized = cimage_concatenated[:, 3]
    mask_in_unit_range = (mask_normalized + 1) / 2

    # Per-token bias: 0 inside the mask, -inf outside. Shape (B, num_tokens).
    per_token_bias = get_token_mask_bias(mask_in_unit_range, patch_size=patch_size)

    # Broadcast the key/value bias across every query position -> (B, 1, Q, KV).
    num_tokens = per_token_bias.shape[-1]
    return per_token_bias[:, None, None, :].repeat(1, 1, num_tokens, 1)
243
+
244
# From SiglipEncoderLayer. We would like to modify this to cross-attention.
class CrossAttnEncoderLayer(nn.Module):
    # Pre-norm transformer layer whose attention is cross-attention onto
    # context features, with learnable gates (attn_factor / mlp_factor) on
    # both residual branches.
    def __init__(self, config: ContextProviderConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.cross_attn = ContextProviderCrossAttention(config)
        self.residual_dropout = nn.Dropout(config.residual_dropout)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = ContextProviderMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        if config.zero_init_output:
            # TODO: alternatively, we could parameterize with an MLP
            # These factors are initialized with 0 (so only residual passes through)
            # NOTE(review): "cross_attn_at_the_end" does not match the
            # "cross_attn_end_to_all" type string used elsewhere in this file,
            # so the first branch appears to always be taken — confirm intended.
            if config.context_provider_type != "cross_attn_at_the_end":
                self.register_parameter("attn_factor", nn.Parameter(torch.zeros((1,))))
                self.register_parameter("mlp_factor", nn.Parameter(torch.zeros((1,))))
            else:
                # Use scalar tensor for compatibility
                self.register_parameter("attn_factor", nn.Parameter(torch.zeros((1,)).view(())))
                self.register_parameter("mlp_factor", nn.Parameter(torch.zeros((1,)).view(())))
        else:
            # Plain floats (not Parameters): gates fixed at 1, i.e. no scaling.
            self.attn_factor = 1.
            self.mlp_factor = 1.

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            encoder_hidden_states (`torch.FloatTensor`):
                Context features serving as keys/values for cross-attention.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # Dropping the residual: let the model leverage more on the context
        hidden_states = self.residual_dropout(residual) + self.attn_factor * hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_factor * hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
310
+
311
class CrossAttnContextProviderEndToAll(nn.Module):
    """Injects final context features into selected vision-tower layers via cross-attention."""

    def __init__(self, config: ContextProviderConfig):
        super().__init__()
        # BUGFIX: the original iterated `enumerate(range(...))`, so `i` was an
        # (index, value) tuple and `i in config.context_provider_layer_indices`
        # could never match an int index — no layers were built whenever
        # explicit indices were configured. Iterate plain ints instead.
        self.layers = nn.ModuleList([
            CrossAttnEncoderLayer(config)
            for i in range(config.num_hidden_layers)
            if config.context_provider_layer_indices is None
            or i in config.context_provider_layer_indices
        ])
        self.patch_size = config.patch_size
        self.masked_cross_attn = config.masked_cross_attn

    def forward(self, context_image_features, cimage_concatenated, vision_tower):
        """Run the vision tower on the detail image, cross-attending to context features."""
        # Optionally restrict cross-attention to tokens inside the region mask
        # (4th channel of the concatenated context image).
        if self.masked_cross_attn:
            attn_mask = attn_mask_from_cimage_concatenated(cimage_concatenated, patch_size=self.patch_size)
        else:
            attn_mask = None

        # Channels 4+ hold the raw detail (crop) image.
        detail_raw_image = cimage_concatenated[:, 4:, ...]
        # NOTE: when using context image as queries, the context image was swapped with the detail image before passing into the context provider
        outputs = vision_tower(detail_raw_image, context_provider_layers=self.layers, contexts=context_image_features, cross_attention_mask=attn_mask)

        return outputs
332
+
333
class ContextProvider(PreTrainedModel):
    """Wrapper model that routes context features into the vision tower.

    Currently supports the "cross_attn_end_to_all" provider type only.
    """

    config_class = ContextProviderConfig

    def __init__(
        self, context_provider_cfg: ContextProviderConfig, config: PretrainedConfig
    ):
        super().__init__(context_provider_cfg)

        self.context_image_as_queries = context_provider_cfg.context_image_as_queries
        self.context_provider_type = context_provider_type = context_provider_cfg.context_provider_type

        self.treat_image_as_cimage = context_provider_cfg.treat_image_as_cimage

        if self.context_image_as_queries:
            # These combinations are not supported by the current implementation.
            assert not context_provider_cfg.masked_cross_attn, "Masked cross-attention not implemented when using context image as queries."
            assert "concat" not in context_provider_type, "Concat not implemented when using context image as queries."

        if context_provider_type == "cross_attn_end_to_all":
            # Information flow: end of context features -> all detail features
            self.context_provider_module = CrossAttnContextProviderEndToAll(context_provider_cfg)
        else:
            raise ValueError(f"Unknown context provider type: {context_provider_type}")

    def forward(self, cimage_full_features=None, cimage_crop_features=None, cimage_concatenated=None, vision_tower=None):
        """Dispatch to the configured provider module.

        ``cimage_crop_features`` is currently unused by the supported type.
        """
        if self.context_provider_type == "cross_attn_end_to_all":
            assert cimage_full_features.shape[0] == cimage_concatenated.shape[0], f"shape mismatches: {cimage_full_features.shape[0]} != {cimage_concatenated.shape[0]}"
            return self.context_provider_module(context_image_features=cimage_full_features, cimage_concatenated=cimage_concatenated, vision_tower=vision_tower)
        else:
            # BUGFIX: the original referenced a bare `context_provider_type`,
            # which is undefined in this scope and raised NameError instead of
            # the intended ValueError.
            raise ValueError(f"Unknown context provider type: {self.context_provider_type}")
362
+
363
+ AutoConfig.register("context_provider", ContextProviderConfig)
364
+ AutoModel.register(ContextProviderConfig, ContextProvider)
dam/model/multimodal_encoder/image_processor.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Image processor class for RADIO."""
17
+ import math
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from PIL import Image
23
+
24
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
25
+ from transformers.image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
26
+ from transformers.image_utils import (
27
+ IMAGENET_DEFAULT_MEAN,
28
+ IMAGENET_DEFAULT_STD,
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ get_image_size,
33
+ infer_channel_dimension_format,
34
+ is_scaled_image,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ )
39
+ from transformers.utils import (
40
+ TensorType,
41
+ is_tf_available,
42
+ is_torch_available,
43
+ is_torchvision_available,
44
+ logging,
45
+ )
46
+
47
+
48
+ if is_torch_available():
49
+ import torch
50
+
51
+ if is_torchvision_available():
52
+ pass
53
+
54
+ if is_tf_available():
55
+
56
+ pass
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+
61
def rank_print(s):
    """Print *s* prefixed with the current distributed rank (0 when torch.distributed is not initialized)."""
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    print(f"[Rank {rank}] {s}")
64
+
65
+ class ImageProcessor(BaseImageProcessor):
66
+ r"""
67
+ Constructs an image processor.
68
+
69
+ Args:
70
+ do_resize (`bool`, *optional*, defaults to `True`):
71
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
72
+ `do_resize` parameter in the `preprocess` method.
73
+ size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
74
+ Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match
75
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
76
+ to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the
77
+ `preprocess` method.
78
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
79
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
80
+ `preprocess` method.
81
+ do_rescale (`bool`, *optional*, defaults to `True`):
82
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
83
+ `do_rescale` parameter in the `preprocess` method.
84
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
85
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
86
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
87
+ do_normalize (`bool`, *optional*, defaults to `True`):
88
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
89
+ method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
90
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
91
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
92
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
93
+ overridden by the `image_mean` parameter in the `preprocess` method.
94
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
95
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
96
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
97
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
98
+ do_pad (`bool`, *optional*, defaults to `True`):
99
+ Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
100
+ `preprocess` method.
101
+ pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
102
+ Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
103
+ method.
104
+ pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`):
105
+ Value of padded pixels.
106
+ pad_multiple (`int`, *optional*, defaults to `None`):
107
+ Pad to a multiple of specified number.
108
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
109
+ Whether to convert the image to RGB.
110
+ """
111
+
112
+ model_input_names = ["pixel_values"]
113
+
114
+ def __init__(
115
+ self,
116
+ do_resize: bool = True,
117
+ size: Dict[str, int] = None,
118
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
119
+ do_rescale: bool = True,
120
+ rescale_factor: Union[int, float] = 1 / 255,
121
+ do_normalize: bool = True,
122
+ image_mean: Optional[Union[float, List[float]]] = None,
123
+ image_std: Optional[Union[float, List[float]]] = None,
124
+ do_pad: bool = True,
125
+ pad_size: int = None,
126
+ pad_multiple: int = None,
127
+ pad_value: Optional[Union[float, List[float]]] = 0.,
128
+ do_convert_rgb: bool = True,
129
+ **kwargs,
130
+ ) -> None:
131
+ super().__init__(**kwargs)
132
+ x = 0
133
+ size = size if size is not None else {"longest_edge": 1024}
134
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
135
+
136
+ if pad_size is not None and pad_multiple is not None:
137
+ raise ValueError("pad_size and pad_multiple should not be set at the same time.")
138
+
139
+ pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024} if pad_multiple is not None else None
140
+ if do_pad:
141
+ pad_size = get_size_dict(pad_size, default_to_square=True)
142
+
143
+ self.do_resize = do_resize
144
+ self.size = size
145
+ self.resample = resample
146
+ self.do_rescale = do_rescale
147
+ self.rescale_factor = rescale_factor
148
+ self.do_normalize = do_normalize
149
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
150
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
151
+ self.do_pad = do_pad
152
+ self.pad_multiple = pad_multiple
153
+ self.pad_size = pad_size
154
+ self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value
155
+ self.do_convert_rgb = do_convert_rgb
156
+ self._valid_processor_keys = [
157
+ "images",
158
+ "segmentation_maps",
159
+ "do_resize",
160
+ "size",
161
+ "resample",
162
+ "do_rescale",
163
+ "rescale_factor",
164
+ "do_normalize",
165
+ "image_mean",
166
+ "image_std",
167
+ "do_pad",
168
+ "pad_size",
169
+ "do_convert_rgb",
170
+ "return_tensors",
171
+ "data_format",
172
+ "input_data_format",
173
+ ]
174
+
175
+ def pad_image(
176
+ self,
177
+ image: np.ndarray,
178
+ pad_size: Dict[str, int],
179
+ data_format: Optional[Union[str, ChannelDimension]] = None,
180
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
181
+ **kwargs,
182
+ ) -> np.ndarray:
183
+ """
184
+ Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom.
185
+
186
+ Args:
187
+ image (`np.ndarray`):
188
+ Image to pad.
189
+ pad_size (`Dict[str, int]`):
190
+ Size of the output image after padding.
191
+ data_format (`str` or `ChannelDimension`, *optional*):
192
+ The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
193
+ `data_format` of the `image` will be used.
194
+ input_data_format (`str` or `ChannelDimension`, *optional*):
195
+ The channel dimension format of the input image. If not provided, it will be inferred.
196
+ """
197
+ output_height, output_width = pad_size["height"], pad_size["width"]
198
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
199
+
200
+ pad_width = output_width - input_width
201
+ pad_height = output_height - input_height
202
+
203
+ padded_image = pad(
204
+ image,
205
+ ((0, pad_height), (0, pad_width)),
206
+ data_format=data_format,
207
+ input_data_format=input_data_format,
208
+ constant_values=self.pad_value,
209
+ **kwargs,
210
+ )
211
+ return padded_image
212
+
213
+ def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
214
+ """
215
+ Compute the output size given input size and target long side length.
216
+ """
217
+ oldh, oldw = old_shape
218
+ scale = longest_edge * 1.0 / max(oldh, oldw)
219
+ newh, neww = oldh * scale, oldw * scale
220
+ newh = int(newh + 0.5)
221
+ neww = int(neww + 0.5)
222
+ return (newh, neww)
223
+
224
+ def resize(
225
+ self,
226
+ image: np.ndarray,
227
+ size: Dict[str, int],
228
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
229
+ data_format: Optional[Union[str, ChannelDimension]] = None,
230
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
231
+ **kwargs,
232
+ ) -> np.ndarray:
233
+ """
234
+ Resize an image to `(size["height"], size["width"])`.
235
+
236
+ Args:
237
+ image (`np.ndarray`):
238
+ Image to resize.
239
+ size (`Dict[str, int]`):
240
+ Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size
241
+ of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match
242
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
243
+ to that size, possibly changing the aspect ratio.
244
+ resample:
245
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
246
+ data_format (`ChannelDimension` or `str`, *optional*):
247
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
248
+ image is used. Can be one of:
249
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
250
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
251
+ input_data_format (`ChannelDimension` or `str`, *optional*):
252
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
253
+ from the input image. Can be one of:
254
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
255
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
256
+
257
+ Returns:
258
+ `np.ndarray`: The resized image.
259
+ """
260
+ size = get_size_dict(size)
261
+ if "longest_edge" not in size:
262
+ if "width" not in size or "height" not in size:
263
+ raise ValueError(f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}")
264
+ input_size = get_image_size(image, channel_dim=input_data_format)
265
+ if "longest_edge" in size:
266
+ output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
267
+ else:
268
+ output_height, output_width = size["height"], size["width"]
269
+ return resize(
270
+ image,
271
+ size=(output_height, output_width),
272
+ resample=resample,
273
+ data_format=data_format,
274
+ input_data_format=input_data_format,
275
+ **kwargs,
276
+ )
277
+
278
    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[Dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        # Apply resize -> rescale -> normalize -> pad, in that order, to a
        # single image array. Returns (processed_image, reshaped_input_size)
        # where reshaped_input_size is the (h, w) after the optional resize
        # and before any padding.
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
        reshaped_input_size = get_image_size(image, channel_dim=input_data_format)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_pad:
            if self.pad_multiple:
                # Round the padded size up to the nearest configured multiple,
                # overriding any pad_size argument.
                h, w = get_image_size(image, channel_dim=input_data_format)
                pad_size = {
                    "height": math.ceil(h / self.pad_multiple) * self.pad_multiple,
                    "width": math.ceil(w / self.pad_multiple) * self.pad_multiple,
                }

            image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)

        return image, reshaped_input_size
314
+
315
+ def _preprocess_image(
316
+ self,
317
+ image: ImageInput,
318
+ do_resize: Optional[bool] = None,
319
+ size: Dict[str, int] = None,
320
+ resample: PILImageResampling = None,
321
+ do_rescale: bool = None,
322
+ rescale_factor: Optional[float] = None,
323
+ do_normalize: Optional[bool] = None,
324
+ image_mean: Optional[Union[float, List[float]]] = None,
325
+ image_std: Optional[Union[float, List[float]]] = None,
326
+ do_pad: Optional[bool] = None,
327
+ pad_size: Optional[Dict[str, int]] = None,
328
+ do_convert_rgb: Optional[bool] = None,
329
+ data_format: Optional[Union[str, ChannelDimension]] = None,
330
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
331
+ ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
332
+ #image = to_numpy_array(image)
333
+
334
+ # import time
335
+ # if int(time.time()*1000) % 10 == 0:
336
+ # # create an PIL image of size 1x1
337
+ # image = PIL.Image.new('RGB', (1, 1))
338
+
339
+ if isinstance(image, Image.Image):
340
+ # PIL always uses Channels Last.
341
+ input_data_format = ChannelDimension.LAST
342
+
343
+ # PIL RGBA images are converted to RGB
344
+ #mode_before = image.mode
345
+ if do_convert_rgb:
346
+ image = convert_to_rgb(image)
347
+
348
+ # All transformations expect numpy arrays.
349
+ image_ = image
350
+ image = to_numpy_array(image)
351
+
352
+ # if isinstance(image_, np.ndarray):
353
+ # rank_print(f"preprocess image type={type(image_)} shape={image_.shape} array shape={image.shape}")
354
+ # elif isinstance(image_, Image.Image):
355
+ # rank_print(f"preprocessimage type={type(image_)} size={image_.size} mode={image_.mode} array shape={image.shape}")
356
+ # else:
357
+ # rank_print(f"preprocess unknown image type={type(image_)} array shape={image.shape}")
358
+
359
+ if len(image.shape) == 2:
360
+ h, w = image.shape
361
+ ret = np.empty((h, w, 3), dtype=np.uint8)
362
+ ret[:, :, 0] = image
363
+ ret[:, :, 1] = image
364
+ ret[:, :, 2] = image
365
+ image = ret
366
+ rank_print(f"preprocess new image shape={image.shape}")
367
+ elif len(image.shape) == 3 and image.shape[-1] == 1:
368
+ ret = np.empty((h, w, 3), dtype=np.uint8)
369
+ ret[:, :, 0] = image[:, :, 0]
370
+ ret[:, :, 1] = image[:, :, 0]
371
+ ret[:, :, 2] = image[:, :, 0]
372
+ image = ret
373
+ rank_print(f"preprocess new image shape={image.shape}")
374
+
375
+ if is_scaled_image(image) and do_rescale:
376
+ logger.warning_once(
377
+ "It looks like you are trying to rescale already rescaled images. If the input"
378
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
379
+ )
380
+
381
+ if input_data_format is None:
382
+ input_data_format = infer_channel_dimension_format(image)
383
+
384
+ original_size = get_image_size(image, channel_dim=input_data_format)
385
+
386
+ image, reshaped_input_size = self._preprocess(
387
+ image=image,
388
+ do_resize=do_resize,
389
+ size=size,
390
+ resample=resample,
391
+ do_rescale=do_rescale,
392
+ rescale_factor=rescale_factor,
393
+ do_normalize=do_normalize,
394
+ image_mean=image_mean,
395
+ image_std=image_std,
396
+ do_pad=do_pad,
397
+ pad_size=pad_size,
398
+ input_data_format=input_data_format,
399
+ )
400
+
401
+ if data_format is not None:
402
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
403
+
404
+ # rank_print(f"preprocess original_size={original_size} reshaped_input_size={reshaped_input_size} image shape={image.shape} type={type(image)}")
405
+
406
+ # if image is a single channel convert to rgb
407
+ if do_convert_rgb and image.shape[0] == 1:
408
+ c, h, w = image.shape
409
+ ret = np.empty((3, h, w), dtype=np.uint8)
410
+ ret[0, :, :] = image[0, :, :]
411
+ ret[1, :, :] = image[0, :, :]
412
+ ret[2, :, :] = image[0, :, :]
413
+ image = ret
414
+ rank_print(f"preprocess final: {image.shape}")
415
+
416
+ return image, original_size, reshaped_input_size
417
+
418
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional["PILImageResampling"] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[Union[int, float]] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[Dict[str, int]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`. The longest edge of the image is resized to
                `size["longest_edge"]` whilst preserving the aspect ratio.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values by rescaling factor.
            rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to apply to the image pixel values.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to normalize the image by if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image.
            pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
                Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
                `pad_size["width"]` if `do_pad` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        # Resolve every argument against the processor-level defaults; an
        # explicit caller value always wins over the `self.*` attribute.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        # A bare int/tuple `size` is normalized into a dict; note that it is
        # passed as `max_size` here — presumably to obtain a longest-edge
        # style dict. NOTE(review): confirm against `get_size_dict` semantics.
        size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_pad = do_pad if do_pad is not None else self.do_pad
        pad_size = pad_size if pad_size is not None else self.pad_size
        if do_pad:
            # Padding targets are always a square-defaulting {height, width} dict.
            pad_size = get_size_dict(pad_size, default_to_square=True)
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        # Normalize the input to a list so single images and batches share one path.
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Each `_preprocess_image` call returns a (image, original_size,
        # reshaped_size) triple; zip(*) transposes the triples into three
        # parallel tuples.
        images, original_sizes, reshaped_input_sizes = zip(
            *(
                self._preprocess_image(
                    image=img,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    do_pad=do_pad,
                    pad_size=pad_size,
                    do_convert_rgb=do_convert_rgb,
                    data_format=data_format,
                    input_data_format=input_data_format,
                )
                for img in images
            )
        )

        data = {
            "pixel_values": images,
            "original_sizes": original_sizes,
            "reshaped_input_sizes": reshaped_input_sizes,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
dam/model/multimodal_encoder/siglip/__init__.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
from typing import TYPE_CHECKING

from transformers.utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_vision_available,
)


# Maps submodule name -> list of public symbols it provides. Consumed by
# `_LazyModule` below so heavy submodules are only imported on first access.
_import_structure = {
    "configuration_siglip": [
        "SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "SiglipConfig",
        "SiglipTextConfig",
        "SiglipVisionConfig",
    ],
    "processing_siglip": ["SiglipProcessor"],
    "tokenization_siglip": ["SiglipTokenizer"],
}

# The image processor is only registered when the optional vision
# dependencies (PIL) are installed.
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]

# The modeling classes are only registered when torch is installed.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_siglip"] = [
        "SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SiglipModel",
        "SiglipPreTrainedModel",
        "SiglipTextModel",
        "SiglipVisionModel",
    ]


if TYPE_CHECKING:
    # Static type checkers see real (eager) imports; at runtime the lazy
    # module in the `else` branch is used instead.
    from .configuration_siglip import (
        SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SiglipConfig,
        SiglipTextConfig,
        SiglipVisionConfig,
    )
    from .processing_siglip import SiglipProcessor
    from .tokenization_siglip import SiglipTokenizer

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_siglip import SiglipImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_siglip import (
            SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            SiglipModel,
            SiglipPreTrainedModel,
            SiglipTextModel,
            SiglipVisionModel,
        )


else:
    import sys

    # Replace this module object with a lazy proxy that imports the
    # submodules listed in `_import_structure` on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
dam/model/multimodal_encoder/siglip/configuration_siglip.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
+ }
29
+
30
+
31
class SiglipTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
    Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`SiglipModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 64):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            The id of the padding token in the vocabulary.
        bos_token_id (`int`, *optional*, defaults to 49406):
            The id of the beginning-of-sequence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 49407):
            The id of the end-of-sequence token in the vocabulary.

    Example:

    ```python
    >>> from transformers import SiglipTextConfig, SiglipTextModel

    >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipTextConfig()

    >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_text_model"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=64,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # This differs from `CLIPTokenizer`'s default and from openai/siglip
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        # Special token ids are handled by the PretrainedConfig base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this text config from a checkpoint, unwrapping a composite
        `SiglipConfig` dict to its `text_config` sub-dict when necessary."""
        # cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the text config dict if we are loading from SiglipConfig
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["text_config"]

        # Warn (but do not fail) when loading a config saved for a different model type.
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
134
+
135
+
136
class SiglipVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_mask_channels (`int`, *optional*, defaults to 0):
            Number of mask channels in the input images.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        # `num_mask_channels` is an extension over upstream SigLIP; used by the
        # describe-anything model for extra mask input channels.
        num_mask_channels=0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.num_mask_channels = num_mask_channels

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this vision config from a checkpoint, unwrapping a composite
        `SiglipConfig` dict to its `vision_config` sub-dict when necessary."""
        # cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from SiglipConfig
        if config_dict.get("model_type") == "siglip":
            config_dict = config_dict["vision_config"]

        # Warn (but do not fail) when loading a config saved for a different model type.
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
234
+
235
+
236
class SiglipConfig(PretrainedConfig):
    r"""
    [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
    instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`SiglipTextConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import SiglipConfig, SiglipModel

    >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipConfig()

    >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
    >>> from transformers import SiglipTextConfig, SiglipVisionConfig

    >>> # Initializing a SiglipText and SiglipVision configuration
    >>> config_text = SiglipTextConfig()
    >>> config_vision = SiglipVisionConfig()

    >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
    ```"""

    model_type = "siglip"

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to the sub-config class defaults.
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")

        self.text_config = SiglipTextConfig(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

        self.initializer_factor = 1.0

    @classmethod
    def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
        model configuration.

        Returns:
            [`SiglipConfig`]: An instance of a configuration object
        """

        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
dam/model/multimodal_encoder/siglip/image_processing_siglip.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SigLIP."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
20
+ from transformers.image_transforms import (
21
+ resize,
22
+ rescale,
23
+ normalize,
24
+ to_channel_dimension_format,
25
+ get_resize_output_image_size,
26
+ get_channel_dimension_axis,
27
+ convert_to_rgb,
28
+ )
29
+ from transformers.image_utils import (
30
+ IMAGENET_STANDARD_MEAN,
31
+ IMAGENET_STANDARD_STD,
32
+ ChannelDimension,
33
+ ImageInput,
34
+ PILImageResampling,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ )
39
+ from transformers.utils import TensorType, is_vision_available, logging
40
+ import numpy as np
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
def is_scaled_image(image: np.ndarray) -> bool:
    """Heuristically decide whether *image* was already rescaled to [0, 1].

    uint8 data is always treated as unscaled; for floating-point data, the
    image is considered scaled when every value lies in [0, 1]. Note a float
    image genuinely valued in [0, 255] with all pixels <= 1 would also match.
    """
    # uint8 pixel data is by definition raw [0, 255] values.
    if image.dtype == np.uint8:
        return False

    lo = np.min(image)
    hi = np.max(image)
    return (lo >= 0) and (hi <= 1)
53
+
54
+
55
+ if is_vision_available():
56
+ import PIL
57
+
58
+
59
+ class SiglipImageProcessor(BaseImageProcessor):
60
+ r"""
61
+ Constructs a SigLIP image processor.
62
+
63
+ Args:
64
+ do_resize (`bool`, *optional*, defaults to `True`):
65
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
66
+ `do_resize` in the `preprocess` method.
67
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
68
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
69
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
70
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
71
+ do_rescale (`bool`, *optional*, defaults to `True`):
72
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
73
+ the `preprocess` method.
74
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
75
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
76
+ method.
77
+ do_normalize (`bool`, *optional*, defaults to `True`):
78
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
79
+ `do_normalize` in the `preprocess` method.
80
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
81
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
82
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
83
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
84
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
85
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
86
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
87
+ """
88
+
89
+ model_input_names = ["pixel_values"]
90
+
91
    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        """Initialize the processor, filling unset arguments with defaults.

        Note the actual default `size` is `{"shortest_edge": 384}` (the image
        is resized so its shortest edge is 384, preserving aspect ratio).
        """
        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 384}
        # `default_to_square=False` keeps the shortest-edge semantics.
        size = get_size_dict(size, default_to_square=False)
        # ImageNet "standard" mean/std are (0.5, 0.5, 0.5) per channel.
        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb
119
+
120
def resize(
    self,
    image: np.ndarray,
    size: Dict[str, int],
    resample: PILImageResampling = PILImageResampling.BICUBIC,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> np.ndarray:
    """
    Resize an image.

    If `size` contains `"shortest_edge"`, the shortest edge of the image is resized to
    `size["shortest_edge"]` with the longest edge resized to keep the input aspect ratio.
    If `size` contains both `"height"` and `"width"`, the image is resized to exactly
    `(height, width)`.

    Args:
        image (`np.ndarray`):
            Image to resize.
        size (`Dict[str, int]`):
            Size of the output image. Must contain either `"shortest_edge"` or both
            `"height"` and `"width"`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.

    Raises:
        ValueError: If `size` contains neither `"shortest_edge"` nor both `"height"` and `"width"`.
    """
    # size = get_size_dict(size, default_to_square=False)
    default_to_square = True
    if "shortest_edge" in size:
        # A bare int target: resize the shortest edge to it, preserving aspect ratio.
        size = size["shortest_edge"]
        default_to_square = False
    elif "height" in size and "width" in size:
        # An explicit (height, width) target.
        size = (size["height"], size["width"])
    else:
        raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
    output_size = get_resize_output_image_size(image, size=size, default_to_square=default_to_square)
    # Note: this calls the module-level `resize` transform, not this method.
    return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
153
+
154
def preprocess(
    self,
    images: ImageInput,
    do_resize: bool = None,
    size: Dict[str, int] = None,
    resample: PILImageResampling = None,
    do_rescale: bool = None,
    rescale_factor: float = None,
    do_normalize: bool = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_convert_rgb: bool = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> PIL.Image.Image:
    """
    Preprocess an image or batch of images.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Size of the image after resizing.
        resample (`int`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
            `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB (e.g. dropping a PIL alpha channel).
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
            - Unset: Return a list of `np.ndarray`.
            - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
            - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
            - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - Unset: Use the channel dimension format of the input image.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
    """
    # Fall back to the instance-level defaults for any step argument not given explicitly.
    do_resize = do_resize if do_resize is not None else self.do_resize
    size = size if size is not None else self.size
    size = get_size_dict(size, param_name="size", default_to_square=False)
    resample = resample if resample is not None else self.resample
    # do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
    # crop_size = crop_size if crop_size is not None else self.crop_size
    # crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
    do_rescale = do_rescale if do_rescale is not None else self.do_rescale
    rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
    do_normalize = do_normalize if do_normalize is not None else self.do_normalize
    image_mean = image_mean if image_mean is not None else self.image_mean
    image_std = image_std if image_std is not None else self.image_std
    do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb


    images = make_list_of_images(images)

    if not valid_images(images):
        raise ValueError(
            "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
            "torch.Tensor, tf.Tensor or jax.ndarray."
        )

    # Validate that every enabled step has the arguments it needs.
    if do_resize and size is None:
        raise ValueError("Size must be specified if do_resize is True.")

    if do_rescale and rescale_factor is None:
        raise ValueError("Rescale factor must be specified if do_rescale is True.")

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("Image mean and std must be specified if do_normalize is True.")

    # PIL RGBA images are converted to RGB
    if do_convert_rgb:
        images = [convert_to_rgb(image) for image in images]

    # All transformations expect numpy arrays.
    images = [to_numpy_array(image) for image in images]

    # Heuristic double-rescale guard: warn if values already look like [0, 1].
    if is_scaled_image(images[0]) and do_rescale:
        logger.warning_once(
            "It looks like you are trying to rescale already rescaled images. If the input"
            " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
        )

    # NOTE(review): `input_data_format` is accepted but never used below — the channel-format
    # inference is commented out, so the transforms fall back to their own defaults.
    # Confirm whether it should be forwarded to resize/rescale/normalize.
    # if input_data_format is None:
    #     # We assume that all images have the same channel dimension format.
    #     input_data_format = infer_channel_dimension_format(images[0])

    if do_resize:
        images = [self.resize(image=image, size=size, resample=resample) for image in images]

    if do_rescale:
        images = [rescale(image=image, scale=rescale_factor) for image in images]

    if do_normalize:
        # Force channel-last layout and expand grayscale images to 3 channels so a
        # per-channel mean/std of length 3 can be applied uniformly.
        output_images = []
        for image in images:
            if get_channel_dimension_axis(image) == 0:
                image = image.transpose((1, 2, 0))
            if image.shape[-1] == 1:
                image = np.dstack((image, image, image))
            output_images.append(image)
        images = output_images
        # for image in images:
        #     # print("image shape", image.shape)
        #     channel_axis = get_channel_dimension_axis(image)
        #     num_channels = image.shape[channel_axis]
        #     if num_channels != len(image_mean):
        #         print("image_mean", image_mean)
        #         print("channel_axis", channel_axis)
        #         print("num_channels", num_channels)
        #         print("image.shape", image.shape)
        #         raise ValueError(
        #             f"Number of channels in the image ({num_channels}) does not match the length of image mean "
        #             f"({len(image_mean)})."
        #         )

        images = [normalize(image=image, mean=image_mean, std=image_std) for image in images]

    # Convert to the requested output channel layout (channel-first by default).
    images = [to_channel_dimension_format(image, data_format) for image in images]


    data = {"pixel_values": images}
    return BatchFeature(data=data, tensor_type=return_tensors)
dam/model/multimodal_encoder/siglip/modeling_siglip.py ADDED
@@ -0,0 +1,1269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ import math
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import Any, Optional, Tuple, Union, List
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn.init import _calculate_fan_in_and_fan_out
28
+
29
+ from transformers.activations import ACT2FN
30
+ # from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
31
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import (
34
+ ModelOutput,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ # _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
46
+
47
+ # SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
48
+ # "google/siglip-base-patch16-224",
49
+ # # See all SigLIP models at https://huggingface.co/models?filter=siglip
50
+ # ]
51
+
52
+
53
+ def _trunc_normal_(tensor, mean, std, a, b):
54
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
55
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
56
+ def norm_cdf(x):
57
+ # Computes standard normal cumulative distribution function
58
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
59
+
60
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
61
+ warnings.warn(
62
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
63
+ "The distribution of values may be incorrect.",
64
+ stacklevel=2,
65
+ )
66
+
67
+ # Values are generated by using a truncated uniform distribution and
68
+ # then using the inverse CDF for the normal distribution.
69
+ # Get upper and lower cdf values
70
+ l = norm_cdf((a - mean) / std)
71
+ u = norm_cdf((b - mean) / std)
72
+
73
+ # Uniformly fill tensor with values from [l, u], then translate to
74
+ # [2l-1, 2u-1].
75
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
76
+
77
+ # Use inverse cdf transform for normal distribution to get truncated
78
+ # standard normal
79
+ tensor.erfinv_()
80
+
81
+ # Transform to proper mean, std
82
+ tensor.mul_(std * math.sqrt(2.0))
83
+ tensor.add_(mean)
84
+
85
+ # Clamp to ensure it's in the proper range
86
+ tensor.clamp_(min=a, max=b)
87
+
88
+
89
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fill *tensor* in place with a TF/JAX-style truncated normal.

    Unlike the plain PyTorch variant, the truncation bounds [a, b] are applied
    while sampling from N(0, 1); the result is subsequently scaled by *std*
    and shifted by *mean*. Works best when a <= mean <= b.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        # Sample a standard truncated normal first, then rescale.
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std)
        tensor.add_(mean)
113
+
114
+
115
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize *tensor* in place using variance scaling (as in TF/JAX).

    The target variance is `scale / denom`, where `denom` is derived from the
    tensor's fan-in/fan-out according to *mode*.

    Args:
        tensor: an n-dimensional `torch.Tensor` filled in place.
        scale: multiplicative factor applied to the variance.
        mode: one of `"fan_in"`, `"fan_out"` or `"fan_avg"`.
        distribution: one of `"truncated_normal"`, `"normal"` or `"uniform"`.

    Raises:
        ValueError: if *mode* or *distribution* is not one of the accepted values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Fix: an unrecognized mode previously fell through and crashed below with
        # an UnboundLocalError on `denom`; fail fast with a clear message instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
138
+
139
+
140
def lecun_normal_(tensor):
    # LeCun normal init: fan-in variance scaling drawn from a truncated normal.
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
142
+
143
+
144
def default_flax_embed_init(tensor):
    # Default Flax embedding init: fan-in variance scaling, untruncated normal.
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
146
+
147
+
148
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Fields default to None so the output can be constructed incrementally.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
176
+
177
+
178
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Fields default to None so the output can be constructed incrementally.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
206
+
207
+
208
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output(`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output(`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
    """

    # Fields default to None so the output can be constructed incrementally.
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        # Recursively tuple-ify the nested model outputs; other entries pass through.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
244
+
245
+
246
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings (with an optional extra mask-channel embedding) plus learned absolute positions."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # print(f"Number of mask channels: {config.num_mask_channels}")
        if config.num_mask_channels:
            # Mask should have the same output shape to be added.
            # Currently we have bias in this embedding (so that mask vs no mask are different).
            self.mask_patch_embedding = nn.Conv2d(
                in_channels=config.num_mask_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid"
            )
            # NOTE(review): flag presumably read by weight-init code elsewhere to zero-init
            # this conv so the mask contribution starts at zero — confirm at the init site.
            self.mask_patch_embedding.use_zero_init = True
        else:
            self.mask_patch_embedding = None

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor, additional_position_embedding: Optional[torch.Tensor] = None, additional_embedding_mode: Optional[str] = None) -> torch.Tensor:
        """Embed `pixel_values` into a (batch, num_patches, embed_dim) sequence with positions added.

        When a mask embedding is configured, the first 3 channels are patch-embedded and a
        4th channel (if present) is embedded separately and added on top.
        `additional_embedding_mode` selects how `additional_position_embedding` combines with
        the learned position embedding: `"add"` sums both, `"replace"` uses only the
        additional one. Raises ValueError for any other mode.
        """
        if self.mask_patch_embedding is None:
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        else:
            # Comment this out if you want to encode both images without mask channel and with mask channel.
            # However, if different samples in the batch have different number of channels, this is not applicable.
            # assert pixel_values.size(1) == 4, f"Input does not have a mask channel, shape: {pixel_values.shape}"
            patch_embeds = self.patch_embedding(pixel_values[:, :3, ...])  # shape = [*, width, grid, grid]
            if pixel_values.size(1) == 4:
                # Add the mask-channel contribution on top of the RGB patch embedding.
                patch_embeds = patch_embeds + self.mask_patch_embedding(pixel_values[:, 3:4, ...])
        # (batch, embed_dim, grid, grid) -> (batch, grid*grid, embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if additional_position_embedding is not None:
            if additional_embedding_mode == "add":
                embeddings = embeddings + self.position_embedding(self.position_ids)
                embeddings = embeddings + additional_position_embedding
            elif additional_embedding_mode == "replace":
                # The original positional embedding is not used (multiplied by zero to ensure all parameters are used to be safe)
                embeddings = embeddings + self.position_embedding(self.position_ids) * 0.
                embeddings = embeddings + additional_position_embedding
            else:
                raise ValueError(f"additional_embedding_mode should be either 'add' or 'replace', got {additional_embedding_mode}")
        else:
            # Without additional position embedding
            embeddings = embeddings + self.position_embedding(self.position_ids)
            # print("No additional position embedding")
        return embeddings
309
+
310
+
311
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
312
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the SigLIP text tower."""

    def __init__(self, config: "SiglipTextConfig"):
        super().__init__()
        width = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, width)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` (or use precomputed `inputs_embeds`) and add position embeddings."""
        if input_ids is not None:
            seq_len = input_ids.shape[-1]
        else:
            seq_len = inputs_embeds.shape[-2]

        # Default to the first `seq_len` positions when none are supplied.
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_len]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
343
+
344
+
345
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, _ = hidden_states.size()

        def split_heads(proj: torch.Tensor) -> torch.Tensor:
            # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)
            return proj.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(hidden_states))
        keys = split_heads(self.k_proj(hidden_states))
        values = split_heads(self.v_proj(hidden_states))

        src_len = keys.shape[-2]
        attn_weights = torch.matmul(queries, keys.transpose(2, 3)) * self.scale

        expected_weights_shape = (bsz, self.num_heads, tgt_len, src_len)
        if attn_weights.size() != expected_weights_shape:
            raise ValueError(
                f"Attention weights should be of size {expected_weights_shape}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            expected_mask_shape = (bsz, 1, tgt_len, src_len)
            if attention_mask.size() != expected_mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask_shape}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, values)

        expected_output_shape = (bsz, self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_output_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_output_shape}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz, heads, seq, head_dim) -> (bsz, seq, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        return self.out_proj(attn_output), attn_weights
419
+
420
+
421
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
422
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: hidden -> intermediate -> hidden with a configurable activation."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand, apply the nonlinearity, then project back to the model width.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
435
+
436
+
437
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
438
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention and MLP sub-blocks, each with a residual connection."""

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # Self-attention sub-block (pre-norm, then residual add).
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm, then residual add).
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
485
+
486
+
487
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights.

        Dispatches on the concrete module type; order matters because the
        specialized SigLIP branches must be checked before the generic
        `nn.Linear` / `nn.Conv2d` fallback.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # Position embeddings are scaled by 1/sqrt(width), where width is the
            # vision hidden size (taken from vision_config for a combined config).
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            # Matches the flax reference initialization for embeddings.
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            # Xavier-uniform projections with zero biases.
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            # Near-zero (but not exactly zero) biases for the MLP.
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) == 0.0, i.e. the contrastive temperature starts at 1.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, nn.Conv2d) and getattr(module, "use_zero_init", False):
            param_list = [module.weight]
            if module.bias is not None:
                param_list += [module.bias]
            # This is used in mask patch embedding
            #

            # with deepspeed.zero.GatheredParameters(param_list, modifier_rank=0):
            #     for param in param_list:
            #         nn.init.zeros_(param)

            # Zero-init so the mask patch embedding starts as a no-op.
            for param in param_list:
                nn.init.zeros_(param)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
550
+
551
+
552
# Shared docstring fragments consumed by the @add_start_docstrings* decorators
# on the model classes below. These are runtime string constants — the text is
# injected verbatim into the generated API documentation.
SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

SIGLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
650
+
651
+
652
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally (e.g. by gradient_checkpointing_enable()).
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        context_provider_layers: Optional[nn.ModuleList] = None,
        contexts: Optional[List[torch.Tensor]] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            context_provider_layers (nn.ModuleList): ModuleList of context provider layers.
            contexts: List of torch.Tensor for context (for KV in cross-attention).
            cross_attention_mask (`torch.Tensor` of shape `(batch_size, q_sequence_length, kv_sequence_length)`, *optional*): mask for cross-attention.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Resolve None flags against the model config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for layer_index, encoder_layer in enumerate(self.layers):
            # hidden_states snapshot BEFORE this layer is recorded.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Re-compute activations during backward to save memory.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            if context_provider_layers:
                # Right now contexts is passed as the encoder_hidden_states (the output hidden_states of the context ViT).
                # A None entry means this layer has no context injection.
                context_provider_layer = context_provider_layers[layer_index]
                if context_provider_layer is not None:
                    if self.gradient_checkpointing and self.training:
                        layer_outputs = self._gradient_checkpointing_func(
                            context_provider_layer.__call__,
                            hidden_states,
                            contexts,
                            cross_attention_mask,
                            output_attentions,
                        )
                    else:
                        layer_outputs = context_provider_layer(
                            hidden_states,
                            contexts,
                            cross_attention_mask,
                            output_attentions=output_attentions,
                        )

                    hidden_states = layer_outputs[0]

        # Final hidden state is appended after the loop so the tuple has
        # num_hidden_layers + 1 entries.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
767
+
768
+
769
class SiglipTextTransformer(nn.Module):
    """Text tower: token embeddings -> transformer encoder -> final LayerNorm,
    with a linear `head` applied to the last token as the pooled output."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # Projects the pooled (last-token) representation.
        self.head = nn.Linear(embed_dim, embed_dim)

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        # Flatten any leading batch dims down to (batch, seq_len).
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        # expand attention_mask
        # if attention_mask is not None:
        #     # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
        #     attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        # Extra kwargs (e.g. context-provider arguments) are forwarded verbatim.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # Assuming "sticky" EOS tokenization, last token is always EOS.
        pooled_output = last_hidden_state[:, -1, :]
        pooled_output = self.head(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
841
+
842
+
843
@add_start_docstrings(
    """The text model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
    """Thin `PreTrainedModel` wrapper that delegates to [`SiglipTextTransformer`]."""

    config_class = SiglipTextConfig

    _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding table of the underlying text tower."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding table of the underlying text tower."""
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # All remaining default resolution happens inside the text tower.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
903
+
904
+
905
class SiglipVisionTransformer(nn.Module):
    """Vision tower: patch embeddings -> transformer encoder -> final
    LayerNorm, with an attention-pooling head producing the pooled output."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.head = SiglipMultiheadAttentionPoolingHead(config)

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Fall back to the config defaults when flags are not given.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        patch_embeds = self.embeddings(pixel_values)

        # Extra kwargs (e.g. context-provider arguments) are forwarded verbatim.
        encoder_outputs = self.encoder(
            inputs_embeds=patch_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        sequence_output = self.post_layernorm(encoder_outputs[0])
        pooled = self.head(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (sequence_output, pooled) + encoder_outputs[1:]
960
+
961
+
962
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Pool a token sequence into a single vector by cross-attending a learned
    probe token to the sequence, then applying a residual MLP block."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        # Learned query token that attends over the encoder outputs.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        """Return the pooled representation, shape (batch, hidden_size)."""
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)
        attended, _ = self.attention(queries, hidden_state, hidden_state)

        # Residual MLP block over the probe token.
        pooled = attended + self.mlp(self.layernorm(attended))
        return pooled[:, 0]
984
+
985
+
986
@add_start_docstrings(
    """The vision model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipVisionModel(SiglipPreTrainedModel):
    """Thin `PreTrainedModel` wrapper that delegates to [`SiglipVisionTransformer`]."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch embedding module of the underlying vision tower."""
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The vision tower resolves the remaining defaults itself.
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
1046
+
1047
+
1048
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
    # Dual-tower SigLIP model: a text transformer and a vision transformer
    # whose pooled, L2-normalized embeddings are compared with a learned
    # logit scale and bias (sigmoid contrastive formulation).
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        # Guard against a mis-assembled combined config.
        if not isinstance(config.text_config, SiglipTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.text_model = SiglipTextTransformer(text_config)
        self.vision_model = SiglipVisionTransformer(vision_config)

        # Scalar temperature and bias for the sigmoid contrastive logits;
        # actual values are set in SiglipPreTrainedModel._init_weights.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooler output regardless of tuple/ModelOutput form.
        pooled_output = text_outputs[1]

        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooler output regardless of tuple/ModelOutput form.
        pooled_output = vision_outputs[1]

        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> inputs = processor(text=texts, images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # The sigmoid pairwise loss from the SigLIP paper is not wired up here.
            raise NotImplementedError("SigLIP loss to be implemented")

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
dam/model/multimodal_encoder/siglip/processing_siglip.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for SigLIP.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.image_utils import ImageInput
23
+ from transformers.processing_utils import ProcessorMixin
24
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25
+ from transformers.utils import TensorType
26
+
27
+
28
class SiglipProcessor(ProcessorMixin):
    r"""
    Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.

    [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
    [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.

    Args:
        image_processor ([`SiglipImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`SiglipTokenizer`]):
            The tokenizer is a required input.
    """

    # ProcessorMixin wires these class attributes up to save/load logic.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "SiglipTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` argument to
        SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the
        docstring of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """

        # At least one of the two modalities must be supplied.
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(
                text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
            )

        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors)

        # Merge image features into the text encoding when both are present;
        # otherwise return whichever single modality was processed.
        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif text is not None:
            return encoding
        else:
            return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
    def model_input_names(self):
        # Union of both components' input names, de-duplicated while preserving order.
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
dam/model/multimodal_encoder/siglip/tokenization_siglip.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization class for SigLIP model."""
16
+
17
import os
import re
import string
import warnings
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import sentencepiece as spm

from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken

# BUGFIX: `logging` and `requires_backends` were previously imported inside the
# `if TYPE_CHECKING:` guard, yet both are used at runtime (`logging.get_logger`
# below and `requires_backends` in `SiglipTokenizer.__init__`), which raised a
# NameError as soon as this module was imported. They must be real imports.
from transformers.utils import logging, requires_backends


if TYPE_CHECKING:
    # Only needed for type annotations.
    from transformers.tokenization_utils_base import TextInput


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/siglip-base-patch16-224": 256,
}

# SentencePiece's meta symbol marking a word-initial piece.
SPIECE_UNDERLINE = "▁"
53
class SiglipTokenizer(PreTrainedTokenizer):
    """
    Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"</s>"`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        model_max_length (`int`, *optional*, defaults to 64):
            The maximum length (in number of tokens) for model inputs.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="</s>",
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        model_max_length=64,
        do_lower_case=True,
        **kwargs,
    ) -> None:
        # Serialized-proto surgery in `get_spm_processor` needs protobuf.
        requires_backends(self, "protobuf")

        # Wrap plain strings in AddedToken so whitespace around the special
        # tokens is stripped consistently.
        pad_token = (
            AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
            if isinstance(pad_token, str)
            else pad_token
        )
        unk_token = (
            AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
            if isinstance(unk_token, str)
            else unk_token
        )
        eos_token = (
            AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
            if isinstance(eos_token, str)
            else eos_token
        )

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # NOTE(review): do_lower_case is stored and serialized but not applied in
        # `_tokenize` (canonicalize_text does not lowercase) — confirm intended.
        self.do_lower_case = do_lower_case
        self.vocab_file = vocab_file

        self.sp_model = self.get_spm_processor()

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            model_max_length=model_max_length,
            do_lower_case=do_lower_case,
            **kwargs,
        )

    def get_spm_processor(self):
        """Load the SentencePiece model with `add_dummy_prefix` disabled.

        The prefix space is added manually in `tokenize` instead, so the
        normalizer spec is patched before loading the serialized proto.
        """
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
        model_pb2 = import_protobuf()
        model = model_pb2.ModelProto.FromString(sp_model)
        normalizer_spec = model_pb2.NormalizerSpec()
        normalizer_spec.add_dummy_prefix = False
        model.normalizer_spec.MergeFrom(normalizer_spec)
        sp_model = model.SerializeToString()
        tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # normal case: only the appended EOS token(s) are special
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
                " eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format:

        - single sequence: `X </s>`
        - pair of sequences: `A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return token_ids_0 + token_ids_1

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
    def __getstate__(self):
        # The SentencePieceProcessor is not picklable; drop it and reload in __setstate__.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # NOTE(review): this reloads the raw model file, unlike get_spm_processor
        # which disables add_dummy_prefix — confirm whether unpickled tokenizers
        # should go through get_spm_processor instead.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def remove_punctuation(self, text: str) -> str:
        """Strip all ASCII punctuation characters from `text`."""
        return text.translate(str.maketrans("", "", string.punctuation))

    # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
    def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
        """Returns canonicalized `text` (punctuation removed).

        Args:
            text (`str`):
                String to be canonicalized.
            keep_punctuation_exact_string (`str`, *optional*):
                If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
                (but will still remove '{' and '}' that appear separately).
        """
        if keep_punctuation_exact_string:
            text = keep_punctuation_exact_string.join(
                self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
            )
        else:
            text = self.remove_punctuation(text)
        # Collapse any whitespace runs to single spaces and trim the ends.
        text = re.sub(r"\s+", " ", text)
        text = text.strip()

        return text

    def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens.
        """
        # Manually prepend the prefix space since add_dummy_prefix is disabled
        # in the underlying SentencePiece model (see get_spm_processor).
        tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)

        # Drop a lone leading underline left in front of a special token.
        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE.

        For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.

        Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        text = self.canonicalize_text(text, keep_punctuation_exact_string=None)

        # BUGFIX: a first `self.sp_model.encode(text, ...)` call whose result was
        # immediately overwritten has been removed (pure wasted work).
        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        # since we manually add the prefix space, we have to remove it
        tokens[0] = tokens[0].lstrip(SPIECE_UNDERLINE)
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy (or re-serialize) the SentencePiece model file into `save_directory`."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Original file is gone; dump the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
dam/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch # noqa
18
+ from .vision_encoder import VisionTower
19
+
20
+ from transformers import AutoConfig, PretrainedConfig, AutoModel
21
+ from .siglip import (
22
+ SiglipVisionConfig,
23
+ SiglipVisionModel,
24
+ SiglipImageProcessor,
25
+ )
26
+
27
+
28
class SiglipVisionTower(VisionTower):
    """VisionTower backed by a pretrained SigLIP vision encoder.

    Loads both the SigLIP image processor and the vision model from
    `model_name_or_path` and marks the tower as loaded.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, state_dict=None):
        super().__init__(model_name_or_path, config)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            # TODO(ligeng): why pass config here leading to errors?
            # NOTE(review): eval() executes whatever string config.model_dtype holds —
            # presumably values like "torch.float16"; confirm and consider a safer
            # lookup (e.g. getattr on torch) since eval on config text is risky.
            model_name_or_path, torch_dtype=eval(config.model_dtype), state_dict=state_dict
        )
        self.is_loaded = True
37
+
38
+
39
+ AutoConfig.register("siglip_vision_model", SiglipVisionConfig, exist_ok=True)
40
+ AutoModel.register(SiglipVisionConfig, SiglipVisionModel, exist_ok=True)
41
+
dam/model/multimodal_encoder/vision_encoder.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from accelerate.hooks import add_hook_to_module
23
+ from transformers import PreTrainedModel
24
+ from transformers.image_processing_utils import BaseImageProcessor
25
+
26
class VisionTower(nn.Module):
    """Base wrapper around a vision backbone used as the multimodal encoder.

    Subclasses (e.g. SiglipVisionTower) are expected to set `self.vision_tower`
    and `self.image_processor`, then flip `self.is_loaded`.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # Which hidden-state layer of the backbone to take features from
        # (default: second-to-last layer).
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        # "patch" drops the CLS token; "cls_patch" keeps all tokens.
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        # Config placeholder used before the backbone weights are loaded.
        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden-state layer and token subset."""
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            # Drop the first token (assumed CLS) and keep only patch tokens.
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor: BaseImageProcessor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        """Linearly interpolate the backbone's position embeddings to support
        a higher input resolution, updating model/processor configs in place.

        No-op when `resolution` is -1 or already matches the model's image size.
        """
        if resolution in [model.config.image_size, -1]:
            return
        print(f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ...")
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        # New number of position tokens for a square grid at the target resolution.
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
                ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]

                # At inference time, we assume deepspeed zero3 is not enabled.
                # import deepspeed

                # with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                #     old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                # Map each new position index into the old [0, old_num_tokens-1] range.
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
                ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)

                # At inference time, we assume deepspeed zero3 is not enabled.
                # params = [old_embeddings.weight, new_embeddings.weight]
                # with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                #     interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
                #         ceil_indices, :
                #     ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]

                # Linear blend between the two nearest old embeddings.
                interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
                    ceil_indices, :
                ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]

                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        # Re-attach any accelerate hook so offloading/dispatch keeps working.
        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        ## update vision encoder's configurations
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP vision tower
            image_processor.crop_size = resolution
        else:
            # SIGLIP vision tower
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}
        ## TODO define a '_reinitialize' method for VisionTower
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
        )

    def forward(self, images, **kwargs):
        """Run the backbone and return the selected features.

        Accepts either a batched tensor or a list of per-image tensors
        (each forwarded with an added batch dimension of 1).
        """
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True, **kwargs,
                )
                # Cast back to each input image's dtype.
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True, **kwargs,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # Zero vector matching the feature width; used as a placeholder.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before loading, fall back to the lightweight config-only object.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # Number of patch tokens for a square input grid.
        return (self.config.image_size // self.config.patch_size) ** 2
dam/model/multimodal_projector/base_projector.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch.nn as nn
18
+ import re
19
+ import torch
20
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
21
+
22
+
23
+ class IdentityMap(nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+
27
+ def forward(self, x, *args, **kwargs):
28
+ return x
29
+
30
+ @property
31
+ def config(self):
32
+ return {"mm_projector_type": "identity"}
33
+
34
+
35
+ class SimpleResBlock(nn.Module):
36
+ def __init__(self, channels):
37
+ super().__init__()
38
+ self.pre_norm = nn.LayerNorm(channels)
39
+
40
+ self.proj = nn.Sequential(
41
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
42
+ )
43
+
44
+ def forward(self, x):
45
+ x = self.pre_norm(x)
46
+ return x + self.proj(x)
47
+
48
+
49
class DownSampleBlock(nn.Module):
    """Spatially downsample ViT patch embeddings 2x by folding 2x2 pixel
    neighborhoods into the channel dimension.

    Input:  (B, N, C) where N is a perfect square (H == W == sqrt(N)).
    Output: (B, ceil(H/2) * ceil(W/2), 4 * C); odd sides are zero-padded.
    """

    def forward(self, x):
        batch = x.shape[0]
        side = int(x.shape[1] ** 0.5)  # assumes a square token grid
        grid = x.reshape(batch, side, side, -1)
        folded = self.flat_square(grid)
        return folded.reshape(batch, -1, folded.shape[-1])

    def flat_square(self, x):
        n, w, h, c = x.size()
        # Zero-pad any odd spatial side so both are divisible by 2.
        if w % 2 == 1:
            pad = torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)
            x = torch.concat([x, pad], dim=1).contiguous()
            n, w, h, c = x.size()
        if h % 2 == 1:
            pad = torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)
            x = torch.concat([x, pad], dim=2).contiguous()
            n, w, h, c = x.size()
        # Fold pairs along dim 2, swap spatial dims, then fold pairs again:
        # (n, w, h, c) -> (n, h/2, w/2, 4c).
        x = x.view(n, w, h // 2, c * 2)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(n, h // 2, w // 2, c * 4)
71
+
72
class MultimodalProjectorConfig(PretrainedConfig):
    """HF config for the vision-to-language projector.

    Attributes:
        mm_projector_type: which projector architecture to build
            ("identity", "linear", "mlp_downsample", or "mlp{N}x_gelu").
    """

    model_type = "v2l_projector"

    def __init__(self, mm_projector_type: str = None, **kwargs):
        # Forward **kwargs to PretrainedConfig so standard fields
        # (name_or_path, torch_dtype, ...) survive save/load round-trips;
        # the previous version dropped them silently.
        super().__init__(**kwargs)
        self.mm_projector_type = mm_projector_type
78
+
79
+
80
class MultimodalProjector(PreTrainedModel):
    """Projects vision-tower features into the language model's embedding space."""

    config_class = MultimodalProjectorConfig

    def __init__(
        self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
    ):
        super().__init__(mm_projector_cfg)
        projector_type = mm_projector_cfg.mm_projector_type
        if projector_type == "identity":
            self.layers = IdentityMap()
        elif projector_type == "linear":
            self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif projector_type == "mlp_downsample":
            # 2x spatial downsample (4x channels) followed by a 2-layer MLP.
            self.layers = nn.Sequential(
                DownSampleBlock(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            # "mlp{N}x_gelu": an N-layer MLP with GELU between layers.
            match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
            if match is None:
                raise ValueError(f"Unknown projector type: {projector_type}")
            depth = int(match.group(1))
            layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(depth - 1):
                layers.append(nn.GELU())
                layers.append(nn.Linear(config.hidden_size, config.hidden_size))
            self.layers = nn.Sequential(*layers)

    def forward(self, x, *args, **kwargs):
        """Apply the projector; extra args are accepted and ignored."""
        return self.layers(x)
114
+
115
+ AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
116
+ AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
dam/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+
19
+ import torch # noqa
20
+ import os
21
+
22
+ from .base_projector import MultimodalProjectorConfig, MultimodalProjector
23
+ from transformers import PretrainedConfig, PreTrainedModel
24
+
25
+
26
def build_mm_projector(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Build or resume the multimodal projector.

    Args:
        model_type_or_path: a projector-type string (fresh build) or a
            checkpoint path (when ``config.resume_path`` is set).
        config: the parent model config; supplies ``model_dtype`` and the
            hidden-size fields the projector needs.

    Returns:
        The projector module, or None when no projector is configured.
    """
    if model_type_or_path is None:
        return None

    # NOTE(review): `eval(config.model_dtype)` executes a config-provided
    # string (e.g. "torch.float16"); acceptable only for trusted configs.
    if config.resume_path:
        # Resuming: the argument must be a local checkpoint directory.
        assert os.path.exists(
            model_type_or_path
        ), f"Resume mm projector path {model_type_or_path} does not exist!"
        return MultimodalProjector.from_pretrained(
            model_type_or_path, config, torch_dtype=eval(config.model_dtype)
        )

    # Fresh build: the argument names the projector architecture.
    mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
    return MultimodalProjector(mm_projector_cfg, config).to(eval(config.model_dtype))
dam/model/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+ import os
18
+ import os.path as osp
19
+ from transformers import AutoConfig
20
+ from transformers import PretrainedConfig
21
+ from huggingface_hub import snapshot_download, repo_exists
22
+ from huggingface_hub.utils import HFValidationError
23
+
24
def get_model_config(config):
    """Resolve the path (or type string) of each model sub-component.

    Returns a list aligned with [llm, vision_tower, mm_projector,
    mask_encoder, context_provider]; the optional trailing entries may be
    None. For inline sub-configs the returned path is a sibling directory
    of the checkpoint root named after the component.

    Raises:
        ValueError: if a sub-config is present but no checkpoint root path
            can be determined.
    """
    # `mask_encoder_cfg` and `context_provider_cfg` are optional
    default_keys = [
        "llm_cfg",
        "vision_tower_cfg",
        "mm_projector_cfg",
        "mask_encoder_cfg",
        "context_provider_cfg",
    ]

    # Prefer the config's own name/path; fall back to the resume path.
    if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
        root_path = config._name_or_path
    else:
        root_path = config.resume_path

    # Not a local directory: try resolving it as a HF Hub repo and download.
    if root_path is not None and not osp.exists(root_path):
        try:
            valid_hf_repo = repo_exists(root_path)
        except HFValidationError:
            valid_hf_repo = False
        if valid_hf_repo:
            root_path = snapshot_download(root_path)

    return_list = []
    for key in default_keys:
        cfg = getattr(config, key, None)
        if isinstance(cfg, (dict, PretrainedConfig)):
            # Inline sub-config: its files live under root_path in a
            # directory named after the component (key minus "_cfg").
            # os.path.join raises TypeError when root_path is None; the
            # original caught this with a bare `except:` (which also hid
            # unrelated errors) and only for the dict branch.
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except TypeError as err:
                raise ValueError(
                    f"Cannot find resume path in config for {key}!"
                ) from err
        elif isinstance(cfg, str):
            # Already a path or model identifier.
            return_list.append(cfg)
        elif cfg is None:
            # We still return even if the cfg is None or does not exist
            return_list.append(cfg)

    return return_list
59
+
60
+
61
def is_mm_model(model_path):
    """
    Check if the model at the given path is a visual language model.

    Args:
        model_path (str): The path to the model.

    Returns:
        bool: True if the model is an MM model, False otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    # An MM (LLaVA-style) model advertises "llava" in an architecture name.
    return any("llava" in arch.lower() for arch in config.architectures)
77
+
78
+
79
def auto_upgrade(config):
    """Interactively upgrade an old-style (v0) LLaVA checkpoint in place.

    `config` is a checkpoint path whose name contains "llava" but whose
    stored model_type is still the pre-rename "llama". Prompts the user on
    stdin; on confirmation rewrites model_type/architectures and saves the
    config back to the same path, otherwise exits the process with code 1.
    """
    cfg = AutoConfig.from_pretrained(config)
    # Old checkpoints: path says "llava" but the stored type predates the rename.
    if "llava" in config and "llava" not in cfg.model_type:
        assert cfg.model_type == "llama"
        print(
            "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
        )
        print(
            "You must upgrade the checkpoint to the new code base (this can be done automatically)."
        )
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # NOTE: mutates the config *class*, affecting all its instances.
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = "LlavaLlamaForCausalLM"
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)