# Source blob 64c250f (16,377 bytes); web-viewer line-number gutter removed.
from PIL import Image
from io import BytesIO
import base64
import math
import ast
import re
import torch
from transformers import StoppingCriteria
# Sentinel ids. They are negative, so they cannot collide with non-negative vocabulary ids.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
GANDALF_TOKEN_INDEX = -300

# Default special-token strings.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# NOTE(review): BOS is the same string as EOS here; confirm this is intentional and not a typo for "<s>".
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Image markers.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

# Video markers.
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"

DEFAULT_EOC_TOKEN = "<eoc>"
COR_START_TOKEN = "<cor>"
# The backslash is escaped explicitly: "\c" is not a valid escape sequence and triggers a
# warning on modern Python. The runtime value is unchanged (six characters: < \ c o r >).
COR_END_TOKEN = "<\\cor>"

SEQ_MAX_LEN = 50000

# Encoded PNG whose IHDR chunk declares a 3x3, 8-bit, colour-type-2 (RGB) image.
# Presumably used as a black fallback image, judging by the name; verify against callers.
BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after any-resolution preprocessing.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str | list | tuple): Candidate resolutions. Accepted forms:
            * a range string such as "(1x1),...,(3x3)": the first and last "(AxB)" groups
              bound an inclusive grid of multipliers, each scaled by ``patch_size``;
            * a list/tuple of (width, height) pairs, used as-is;
            * any other string, parsed with ``ast.literal_eval``.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height),
        i.e. the selected resolution floor-divided by ``patch_size``.

    Raises:
        AssertionError: If a range string is given and ``patch_size`` is not one of
            224, 336, 384, 448, 512.
        IndexError: If a range string contains "x" but no "(AxB)" group.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Extract every "(AxB)" pair; only the first and last bound the range.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Enumerate every multiplier pair in the inclusive range, scaled to pixels.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]

    # isinstance (rather than an exact type check) so tuples and list subclasses are used
    # directly instead of being handed to literal_eval, which only accepts strings/AST nodes.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that best fits an image of the given size.

    Each candidate is scored by the pixel area the image would occupy after an
    aspect-preserving fit (capped at the image's own area). The highest score wins;
    ties go to the candidate leaving the fewest unused pixels.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions as
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The chosen (width, height), or None if no candidate scores above zero
        (for example when ``possible_resolutions`` is empty).
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    chosen = None
    best_effective = 0
    least_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        # Aspect-preserving scale that fits the image inside the candidate.
        ratio = min(cand_w / src_w, cand_h / src_h)
        fit_w = int(src_w * ratio)
        fit_h = int(src_h * ratio)

        # Upscaling adds no information, so cap at the source area.
        effective = min(fit_w * fit_h, src_area)
        wasted = cand_w * cand_h - effective

        improves = effective > best_effective
        ties_with_less_waste = effective == best_effective and wasted < least_wasted
        if not (improves or ties_with_less_waste):
            continue

        best_effective, least_wasted = effective, wasted
        chosen = (cand_w, cand_h)

    return chosen
def unpad_image(tensor, original_size):
    """
    Crop away the symmetric padding from a padded-and-resized image tensor.

    Args:
        tensor (torch.Tensor): The image tensor; its last two dimensions are read as
            (height, width) and exactly one leading dimension is expected (CxHxW).
        original_size (tuple): The original image size, unpacked as (width, height).

    Returns:
        torch.Tensor: A slice of ``tensor`` with the padding rows or columns removed.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Original is relatively wider than the canvas: padding sits above and below.
        ratio = cur_w / orig_w
        content_h = int(orig_h * ratio)
        margin = (cur_h - content_h) // 2
        return tensor[:, margin: cur_h - margin, :]

    # Otherwise the padding sits on the left and right.
    ratio = cur_h / orig_h
    content_w = int(orig_w * ratio)
    margin = (cur_w - content_w) // 2
    return tensor[:, :, margin: cur_w - margin]
def _anyres_patch_size(processor):
    """Return the square tile edge from ``processor.size`` ("height", then "shortest_edge", then min)."""
    size = processor.size
    for key in ("height", "shortest_edge"):
        try:
            return size[key]
        except (KeyError, TypeError, IndexError, AttributeError):
            # Key absent, or ``size`` is not string-indexable (e.g. a plain tuple).
            continue
    return min(size)


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is fitted (resize + black padding) to the best candidate resolution and
    cut into square tiles. A copy of the whole image resized to a single tile is
    prepended, and every tile is run through ``processor.preprocess``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. Its ``size`` supplies the tile edge:
            ``size["height"]`` if present, else ``size["shortest_edge"]``, else
            ``min(size)``.
        grid_pinpoints (str | list | tuple): Candidate resolutions. Either a range
            string such as "(1x1),...,(3x3)" (multipliers of the tile edge), a
            list/tuple of (width, height) pairs, or a string literal of such a list.

    Returns:
        list: One ``pixel_values`` tensor per tile, as produced by
        ``processor.preprocess(..., return_tensors="pt")``. The first entry is the
        resized full image, followed by the tiles in row-major order.

    Raises:
        AssertionError: If a range string is given and the tile edge is not one of
            224, 336, 384, 448, 512.
    """
    # Resolve the tile edge once so that the range expansion, the tiling and the
    # overview resize all agree, and so a processor exposing only "shortest_edge"
    # no longer fails when the tiles are cut.
    patch_size = _anyres_patch_size(processor)

    # Convert grid_pinpoints from a range string to an explicit list.
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Extract every "(AxB)" pair; only the first and last bound the range.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Enumerate every multiplier pair in the inclusive range, scaled to pixels.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]

    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)
    patches = divide_to_patches(image_padded, patch_size)

    # FIXME: this seems to be a bug that it resizes instead of pad.
    # It is kept as-is for consistency with previous behaviour.
    # TODO: switch to a padded square (e.g. an expand2square-style helper with the
    # processor's image mean as fill colour) to ablate against padding.
    image_original_resize = image.resize((patch_size, patch_size))

    image_patches = [image_original_resize] + patches
    # The tensors are returned as a list; stacking is left to the caller.
    return [
        processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
        for image_patch in image_patches
    ]
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target canvas, preserving aspect ratio and padding with black.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The canvas size as (width, height).

    Returns:
        PIL.Image.Image: A new RGB image of the target size with the resized input
        pasted at its centre.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w < ratio_h:
        # Width is the limiting dimension and is filled completely.
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * ratio_w), dst_h)
    else:
        # Height is the limiting dimension and is filled completely.
        fit_h = dst_h
        fit_w = min(math.ceil(src_w * ratio_h), dst_w)

    scaled = image.resize((fit_w, fit_h))

    # Black canvas; integer division centres the content (any odd pixel goes right/bottom).
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas
def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles in row-major order.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The edge length of each tile.

    Returns:
        list: The results of ``image.crop`` for each tile box, top-to-bottom then
        left-to-right. Boxes are not clipped to the image bounds.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
from typing import List
import PIL.Image
import torch
import transformers
# Sentinel ids used by the Ovis-style helpers below. All are negative.
# IGNORE_ID has the same value as the label that ovis_template_process assigns to
# image placeholder positions.
IGNORE_ID = -100
# Marks an image slot in `input_ids`; ovis_template_process expands each occurrence
# into that image's placeholder sequence.
IMAGE_TOKEN_ID = -200
IMAGE_TOKEN = "<image>"
# Placeholder id emitted by construct_image_placeholders: once for the overview and once per grid cell.
IMAGE_ATOM_ID = -300
# Structural markers emitted by construct_image_placeholders, by index:
#   [0] sequence start, [1] after the leading overview atom, [2] between cells in a row,
#   [3] between rows, [4] sequence end.
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
def construct_image_placeholders(grid):
    """
    Build the placeholder id sequence for an image split into a (rows, cols) grid.

    The sequence always starts with indicator 0, one atom and indicator 1, and ends
    with indicator 4. When the grid has more than one cell, one atom per cell is
    inserted in between, with indicator 2 separating cells within a row and
    indicator 3 separating rows.

    Args:
        grid: Indexable pair; ``grid[0]`` is the row count and ``grid[1]`` the column count.

    Returns:
        list: The placeholder ids.
    """
    n_rows, n_cols = grid[0], grid[1]
    placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]

    if n_rows * n_cols > 1:
        for row in range(n_rows):
            if row > 0:
                # Separator between consecutive rows.
                placeholders.append(IMAGE_INDICATOR_IDS[3])
            for col in range(n_cols):
                if col > 0:
                    # Separator between consecutive cells of one row.
                    placeholders.append(IMAGE_INDICATOR_IDS[2])
                placeholders.append(IMAGE_ATOM_ID)

    placeholders.append(IMAGE_INDICATOR_IDS[4])
    return placeholders
def preprocess_image_ovis(image: PIL.Image.Image, image_processor, crop_size, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
    """
    Split an image into a grid of crops and preprocess each crop to a padded square.

    A grid of at most ``max_partition`` cells is chosen so that the crops, once shrunk
    to fit ``crop_size``, retain a large enough share of the original pixel area.

    Args:
        image (PIL.Image.Image): The input image.
        image_processor: Object whose ``preprocess(img, size=..., return_tensors='pt')``
            returns a mapping with a ``'pixel_values'`` tensor.
        crop_size: Edge length of the square each crop is padded to.
        max_partition (int): Maximum number of grid cells (rows * cols).
        covering_threshold (float): Minimum covering ratio for a grid to count as good.
        convert_to_rgb (bool): Convert non-RGB images to RGB first.

    Returns:
        tuple: ``(pixel_values, image_placeholders)``. ``pixel_values`` is a list of
        tensors of shape [1, 3, crop_size, crop_size] (left for the caller to
        concatenate). When the grid has more than one cell, the full image comes
        first, followed by the crops in row-major order. ``image_placeholders`` is
        the result of ``construct_image_placeholders(grid)``.
    """

    def _preprocess(img: PIL.Image.Image, side):
        # Step 1: resize so the longer edge equals `side`, then run the processor.
        src_w, src_h = img.size
        if src_w == src_h:
            target_w = target_h = side
        elif src_w > src_h:
            target_w = side
            target_h = int(src_h / src_w * target_w)
        else:
            target_h = side
            target_w = int(src_w / src_h * target_h)
        resized = image_processor.preprocess(
            img, size=dict(height=target_h, width=target_w), return_tensors='pt'
        )['pixel_values']

        # Step 2: centre the result on a zero-filled square canvas.
        canvas = torch.zeros([1, 3, side, side], dtype=resized.dtype, device=resized.device)
        out_h, out_w = resized.shape[2:]
        if out_h == out_w:
            canvas[:, :, :, :] = resized
        elif out_h > out_w:
            start = (side - out_w) // 2
            canvas[:, :, :, start:start + out_w] = resized
        else:
            start = (side - out_h) // 2
            canvas[:, :, start:start + out_h, :] = resized
        return canvas

    def _partition(img, grid):
        # Crop boxes (left, upper, right, lower) in row-major order; the last row and
        # column absorb the remainder of the integer division.
        img_w, img_h = img.size
        n_rows, n_cols = grid[0], grid[1]
        cell_h = img_h // n_rows
        cell_w = img_w // n_cols
        boxes = []
        for r in range(n_rows):
            upper = r * cell_h
            lower = img_h if r == n_rows - 1 else (r + 1) * cell_h
            for c in range(n_cols):
                left = c * cell_w
                right = img_w if c == n_cols - 1 else (c + 1) * cell_w
                boxes.append((left, upper, right, lower))
        return boxes

    def _covering_area(left, upper, right, lower, side):
        # Area of the box after shrinking (never enlarging) it so its longer edge fits `side`.
        span_w = right - left
        span_h = lower - upper
        longer, shorter = max(span_w, span_h), min(span_w, span_h)
        if longer > side:
            shorter = shorter / longer * side
            longer = side
        return longer * shorter

    def _get_best_grid(img, side):
        img_area = img.size[0] * img.size[1]
        candidates = [
            (rows, cols)
            for rows in range(1, max_partition + 1)
            for cols in range(1, max_partition + 1)
            if rows * cols <= max_partition
        ]

        scored = []
        passing = []
        for grid in candidates:
            boxes = _partition(img, grid)
            ratio = sum([_covering_area(*box, side) for box in boxes]) / img_area
            assert ratio <= 1.0
            entry = (grid, ratio)
            scored.append(entry)
            if ratio > covering_threshold:
                passing.append(entry)

        if len(passing) > 0:
            # Fewest sub-images first; ties broken by the higher covering ratio.
            ranked = sorted(passing, key=lambda item: (item[0][0] * item[0][1], -item[1]))
        else:
            # Highest covering ratio first; ties broken by fewer sub-images.
            ranked = sorted(scored, key=lambda item: (-item[1], item[0][0] * item[0][1]))
        return ranked[0][0]

    if convert_to_rgb and image.mode != 'RGB':
        image = image.convert('RGB')

    # Both sides come from the single crop_size argument; the squareness check is
    # retained from the form that queried the model for its image size.
    sides = [crop_size, crop_size]
    if sides[0] != sides[1]:
        raise ValueError('get_image_size() returns non-square size')
    side = sides[0]

    grid = _get_best_grid(image, side)
    crops = [image.crop(box) for box in _partition(image, grid)]
    if len(crops) > 1:
        # Multi-cell grids also get the full image, placed first.
        crops.insert(0, image)

    # Kept as a list; concatenation happens in the caller.
    pixel_values = [_preprocess(crop, side) for crop in crops]
    image_placeholders = construct_image_placeholders(grid)
    return pixel_values, image_placeholders
def ovis_template_process(data_dict):
    """
    Expand every image token in a sample into that image's placeholder ids.

    Args:
        data_dict (dict): Must contain
            * ``'images'``: sequence of pairs whose element [0] is the image payload
              and element [1] is the iterable of placeholder ids for that image;
            * ``'input_ids'``: 1-D tensor in which ``IMAGE_TOKEN_ID`` marks each image slot;
            * ``'labels'``: indexable aligned with ``input_ids``.

    Returns:
        dict: The same ``data_dict``, mutated in place. ``'images'`` keeps only the
        payloads, and ``'input_ids'`` / ``'labels'`` become new tensors in which each
        image token is replaced by its placeholder ids (labelled -100).

    Raises:
        AssertionError: If the number of images differs from the number of image tokens.
    """
    images = data_dict['images']
    input_ids = data_dict['input_ids']
    labels = data_dict['labels']

    placeholders = [entry[1] for entry in images]
    image_positions = torch.nonzero(input_ids == IMAGE_TOKEN_ID).squeeze(1)
    assert len(placeholders) == len(image_positions)

    expanded_ids = []
    expanded_labels = []
    consumed = 0  # number of image slots expanded so far
    for pos, token in enumerate(input_ids):
        if token == IMAGE_TOKEN_ID:
            for placeholder_id in placeholders[consumed]:
                expanded_ids.append(placeholder_id)
                expanded_labels.append(-100)  # placeholder positions carry no training target
            consumed += 1
        else:
            expanded_ids.append(input_ids[pos])
            expanded_labels.append(labels[pos])

    assert len(expanded_ids) == len(expanded_labels)
    assert len(placeholders) == consumed

    # Presumably each payload is shaped like (3, 3, 448, 448); verify against the caller.
    data_dict['images'] = [entry[0] for entry in data_dict['images']]
    data_dict['input_ids'] = torch.tensor(expanded_ids)
    data_dict['labels'] = torch.tensor(expanded_labels)
    return data_dict
def pad_truncate_sequence(multimodal_max_length, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
    """
    Pad variable-length sequences to a common length, then truncate along dimension 1.

    Args:
        multimodal_max_length (int): Maximum number of columns (dimension 1) to keep.
        sequences (List[torch.Tensor]): Sequences to pad.
        batch_first (bool): Forwarded to ``pad_sequence`` for right padding only. The
            left-padding path always pads batch-first. Truncation is applied to
            dimension 1 either way, so ``batch_first=False`` truncates the batch axis.
        padding_value (float): Fill value for the padded positions.
        left_padding (bool): If False, pad on the right and keep the FIRST
            ``multimodal_max_length`` columns. If True, pad on the left and keep the
            LAST ``multimodal_max_length`` columns.

    Returns:
        torch.Tensor: The padded and truncated batch.
    """
    if not left_padding:
        padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
        return padded[:, :multimodal_max_length]

    # pad_sequence only pads on the right, so reverse each sequence, pad, and flip back.
    reversed_sequences = [seq.flip(dims=[0]) for seq in sequences]
    padded = torch.nn.utils.rnn.pad_sequence(
        reversed_sequences, batch_first=True, padding_value=padding_value
    ).flip(dims=[1])

    # Keep the trailing columns. Slicing from `multimodal_max_length:` would instead
    # drop the first N columns and return an empty tensor for any batch that already
    # fits. An explicit start index also keeps a limit of 0 yielding an empty result.
    start = max(padded.shape[1] - multimodal_max_length, 0)
    return padded[:, start:]