English
naveensp commited on
Commit
3e3e14c
·
verified ·
1 Parent(s): 43e609b

Delete mm_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mm_utils.py +0 -247
mm_utils.py DELETED
@@ -1,247 +0,0 @@
1
- from PIL import Image
2
- from io import BytesIO
3
- import base64
4
- import torch
5
- import math
6
- import ast
7
-
8
- from transformers import StoppingCriteria
9
- from llava.constants import IMAGE_TOKEN_INDEX
10
-
11
-
12
- def select_best_resolution(original_size, possible_resolutions):
13
- """
14
- Selects the best resolution from a list of possible resolutions based on the original size.
15
-
16
- Args:
17
- original_size (tuple): The original size of the image in the format (width, height).
18
- possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
-
20
- Returns:
21
- tuple: The best fit resolution in the format (width, height).
22
- """
23
- original_width, original_height = original_size
24
- best_fit = None
25
- max_effective_resolution = 0
26
- min_wasted_resolution = float('inf')
27
-
28
- for width, height in possible_resolutions:
29
- scale = min(width / original_width, height / original_height)
30
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
- wasted_resolution = (width * height) - effective_resolution
33
-
34
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
- max_effective_resolution = effective_resolution
36
- min_wasted_resolution = wasted_resolution
37
- best_fit = (width, height)
38
-
39
- return best_fit
40
-
41
-
42
- def resize_and_pad_image(image, target_resolution):
43
- """
44
- Resize and pad an image to a target resolution while maintaining aspect ratio.
45
-
46
- Args:
47
- image (PIL.Image.Image): The input image.
48
- target_resolution (tuple): The target resolution (width, height) of the image.
49
-
50
- Returns:
51
- PIL.Image.Image: The resized and padded image.
52
- """
53
- original_width, original_height = image.size
54
- target_width, target_height = target_resolution
55
-
56
- scale_w = target_width / original_width
57
- scale_h = target_height / original_height
58
-
59
- if scale_w < scale_h:
60
- new_width = target_width
61
- new_height = min(math.ceil(original_height * scale_w), target_height)
62
- else:
63
- new_height = target_height
64
- new_width = min(math.ceil(original_width * scale_h), target_width)
65
-
66
- # Resize the image
67
- resized_image = image.resize((new_width, new_height))
68
-
69
- new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
- paste_x = (target_width - new_width) // 2
71
- paste_y = (target_height - new_height) // 2
72
- new_image.paste(resized_image, (paste_x, paste_y))
73
-
74
- return new_image
75
-
76
-
77
- def divide_to_patches(image, patch_size):
78
- """
79
- Divides an image into patches of a specified size.
80
-
81
- Args:
82
- image (PIL.Image.Image): The input image.
83
- patch_size (int): The size of each patch.
84
-
85
- Returns:
86
- list: A list of PIL.Image.Image objects representing the patches.
87
- """
88
- patches = []
89
- width, height = image.size
90
- for i in range(0, height, patch_size):
91
- for j in range(0, width, patch_size):
92
- box = (j, i, j + patch_size, i + patch_size)
93
- patch = image.crop(box)
94
- patches.append(patch)
95
-
96
- return patches
97
-
98
-
99
- def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
- """
101
- Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
-
103
- Args:
104
- image_size (tuple): The size of the input image in the format (width, height).
105
- grid_pinpoints (str): A string representation of a list of possible resolutions.
106
- patch_size (int): The size of each image patch.
107
-
108
- Returns:
109
- tuple: The shape of the image patch grid in the format (width, height).
110
- """
111
- if type(grid_pinpoints) is list:
112
- possible_resolutions = grid_pinpoints
113
- else:
114
- possible_resolutions = ast.literal_eval(grid_pinpoints)
115
- width, height = select_best_resolution(image_size, possible_resolutions)
116
- return width // patch_size, height // patch_size
117
-
118
-
119
- def process_anyres_image(image, processor, grid_pinpoints):
120
- """
121
- Process an image with variable resolutions.
122
-
123
- Args:
124
- image (PIL.Image.Image): The input image to be processed.
125
- processor: The image processor object.
126
- grid_pinpoints (str): A string representation of a list of possible resolutions.
127
-
128
- Returns:
129
- torch.Tensor: A tensor containing the processed image patches.
130
- """
131
- if type(grid_pinpoints) is list:
132
- possible_resolutions = grid_pinpoints
133
- else:
134
- possible_resolutions = ast.literal_eval(grid_pinpoints)
135
- best_resolution = select_best_resolution(image.size, possible_resolutions)
136
- image_padded = resize_and_pad_image(image, best_resolution)
137
-
138
- patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
-
140
- image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141
-
142
- image_patches = [image_original_resize] + patches
143
- image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144
- for image_patch in image_patches]
145
- return torch.stack(image_patches, dim=0)
146
-
147
-
148
- def load_image_from_base64(image):
149
- return Image.open(BytesIO(base64.b64decode(image)))
150
-
151
-
152
- def expand2square(pil_img, background_color):
153
- width, height = pil_img.size
154
- if width == height:
155
- return pil_img
156
- elif width > height:
157
- result = Image.new(pil_img.mode, (width, width), background_color)
158
- result.paste(pil_img, (0, (width - height) // 2))
159
- return result
160
- else:
161
- result = Image.new(pil_img.mode, (height, height), background_color)
162
- result.paste(pil_img, ((height - width) // 2, 0))
163
- return result
164
-
165
-
166
- def process_images(images, image_processor, model_cfg):
167
- image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168
- new_images = []
169
- if image_aspect_ratio == 'pad':
170
- for image in images:
171
- image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172
- image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173
- new_images.append(image)
174
- elif image_aspect_ratio == "anyres":
175
- for image in images:
176
- image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177
- new_images.append(image)
178
- else:
179
- return image_processor(images, return_tensors='pt')['pixel_values']
180
- if all(x.shape == new_images[0].shape for x in new_images):
181
- new_images = torch.stack(new_images, dim=0)
182
- return new_images
183
-
184
-
185
- def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186
- prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187
-
188
- def insert_separator(X, sep):
189
- return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190
-
191
- input_ids = []
192
- offset = 0
193
- if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194
- offset = 1
195
- input_ids.append(prompt_chunks[0][0])
196
-
197
- for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198
- input_ids.extend(x[offset:])
199
-
200
- if return_tensors is not None:
201
- if return_tensors == 'pt':
202
- return torch.tensor(input_ids, dtype=torch.long)
203
- raise ValueError(f'Unsupported tensor type: {return_tensors}')
204
- return input_ids
205
-
206
-
207
- def get_model_name_from_path(model_path):
208
- model_path = model_path.strip("/")
209
- model_paths = model_path.split("/")
210
- if model_paths[-1].startswith('checkpoint-'):
211
- return model_paths[-2] + "_" + model_paths[-1]
212
- else:
213
- return model_paths[-1]
214
-
215
- class KeywordsStoppingCriteria(StoppingCriteria):
216
- def __init__(self, keywords, tokenizer, input_ids):
217
- self.keywords = keywords
218
- self.keyword_ids = []
219
- self.max_keyword_len = 0
220
- for keyword in keywords:
221
- cur_keyword_ids = tokenizer(keyword).input_ids
222
- if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
223
- cur_keyword_ids = cur_keyword_ids[1:]
224
- if len(cur_keyword_ids) > self.max_keyword_len:
225
- self.max_keyword_len = len(cur_keyword_ids)
226
- self.keyword_ids.append(torch.tensor(cur_keyword_ids))
227
- self.tokenizer = tokenizer
228
- self.start_len = input_ids.shape[1]
229
-
230
- def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
231
- offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
232
- self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
233
- for keyword_id in self.keyword_ids:
234
- truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
235
- if torch.equal(truncated_output_ids, keyword_id):
236
- return True
237
- outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
238
- for keyword in self.keywords:
239
- if keyword in outputs:
240
- return True
241
- return False
242
-
243
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
- outputs = []
245
- for i in range(output_ids.shape[0]):
246
- outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
247
- return all(outputs)