| | from PIL import Image |
| | from io import BytesIO |
| | import base64 |
| | import math |
| | import ast |
| | import re |
| | import torch |
| | from transformers import StoppingCriteria |
| |
|
# Sentinel token ids used during preprocessing. All are negative so they can
# never collide with real vocabulary ids.
IGNORE_INDEX = -100  # label value masked out of the loss (conventional CE ignore_index)
IMAGE_TOKEN_INDEX = -200  # placeholder id standing in for an <image> token
GANDALF_TOKEN_INDEX = -300
# Special-token strings.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"  # NOTE(review): identical to EOS — BOS is usually "<s>"; confirm this is intended
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"
DEFAULT_EOC_TOKEN = "<eoc>"
COR_START_TOKEN = "<cor>"
COR_END_TOKEN = "<\cor>"  # NOTE(review): "\c" is a literal backslash-c, not "</cor>" — confirm intended
SEQ_MAX_LEN = 50000  # hard cap on multimodal sequence length
# Raw PNG bytes of a 3x3 all-black image, usable as a dummy/fallback image.
BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
| |
|
| |
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for
    images of any resolution.

    Args:
        image_size (tuple): The size of the input image as (width, height).
        grid_pinpoints (str | list): Either a list of candidate (width, height)
            resolutions, a string literal of such a list, or a range spec like
            "(1x1),...,(3x3)" that is expanded into every grid in the range.
        patch_size (int): The side length of each square image patch.

    Returns:
        tuple: The patch-grid shape as (width, height), in patch units.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand a "(AxB),...,(CxD)" range spec into every grid between the
        # first and last listed pair, inclusive.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        # Convert grid cell counts into pixel resolutions.
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    # isinstance (not `type(...) is list`) so list subclasses are accepted too.
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        # A plain string literal such as "[(672, 336), (336, 672)]".
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
| |
|
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that best fits an image of *original_size*.

    The winner maximizes the effective pixel count after downscaling into the
    candidate (capped at the original's own pixel count); ties are broken by
    minimizing wasted canvas area.

    Args:
        original_size (tuple): Original image size as (width, height).
        possible_resolutions (list): Candidate (width, height) pairs.

    Returns:
        tuple: The chosen (width, height), or None if no candidates are given.
    """
    orig_w, orig_h = original_size
    best_fit = None
    best_effective = 0
    best_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        # Largest uniform scale that fits the original inside this candidate.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * scale)
        scaled_h = int(orig_h * scale)

        # Upscaling beyond the original adds no information, so cap it.
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective

        improves = effective > best_effective or (
            effective == best_effective and wasted < best_wasted
        )
        if improves:
            best_effective = effective
            best_wasted = wasted
            best_fit = (cand_w, cand_h)

    return best_fit
| |
|
| |
|
def unpad_image(tensor, original_size):
    """
    Strip the symmetric letterbox padding from a padded-and-resized image tensor.

    Args:
        tensor (torch.Tensor): The image tensor in CxHxW layout.
        original_size (tuple): The pre-padding image size as (width, height).
            (The code unpacks width first — earlier docs saying (height, width)
            did not match the implementation.)

    Returns:
        torch.Tensor: A view of *tensor* with the padding rows/columns removed.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    orig_ratio = orig_w / orig_h
    cur_ratio = cur_w / cur_h

    if orig_ratio > cur_ratio:
        # Original is wider than the canvas: padding was added top and bottom.
        new_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - new_h) // 2
        return tensor[:, pad:cur_h - pad, :]

    # Original is taller (or equal): padding was added left and right.
    new_w = int(orig_w * (cur_h / orig_h))
    pad = (cur_w - new_w) // 2
    return tensor[:, :, pad:cur_w - pad]
| |
|
| |
|
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object; must expose ``size`` (dict or
            sequence) and ``preprocess``.
        grid_pinpoints (str | list): A list of candidate (width, height)
            resolutions, a string literal of such a list, or a
            "(AxB),...,(CxD)" range spec expanded into every grid in range.

    Returns:
        list: Processed patch tensors; element 0 is the whole image resized
        to the base square resolution, followed by the grid patches.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        # Base patch size comes from the processor; fall back to
        # "shortest_edge" for processors whose size dict has no "height".
        try:
            patch_size = processor.size["height"]
        except Exception:
            patch_size = processor.size["shortest_edge"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand "(AxB),...,(CxD)" into every grid between the first and last
        # listed pair, inclusive.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        # Convert grid cell counts into pixel resolutions.
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        # A plain string literal such as "[(672, 336), (336, 672)]".
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    # NOTE(review): indexes size["height"] directly, unlike the hedged lookup
    # above — confirm every processor used here exposes a dict size with
    # "height", or this line raises where the earlier branch would not.
    patches = divide_to_patches(image_padded, processor.size["height"])

    # Global view: the full image resized to a square at the base resolution.
    if isinstance(processor.size, dict):
        shortest_edge = processor.size["height"]
    else:
        shortest_edge = min(processor.size)
    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
        for image_patch in image_patches
    ]

    return image_patches
| |
|
def resize_and_pad_image(image, target_resolution):
    """
    Resize *image* to fit inside *target_resolution* while preserving aspect
    ratio, then center it on a black canvas of exactly that resolution.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): Target (width, height) in pixels.

    Returns:
        PIL.Image.Image: The letterboxed RGB image.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    width_scale = dst_w / src_w
    height_scale = dst_h / src_h

    # Scale by the limiting dimension; ceil then clamp so rounding can never
    # overshoot the canvas.
    if width_scale < height_scale:
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * width_scale), dst_h)
    else:
        fit_h = dst_h
        fit_w = min(math.ceil(src_w * height_scale), dst_w)

    scaled = image.resize((fit_w, fit_h))

    # Black background; any leftover border is the letterbox padding.
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset_x = (dst_w - fit_w) // 2
    offset_y = (dst_h - fit_h) // 2
    canvas.paste(scaled, (offset_x, offset_y))

    return canvas
| |
|
def divide_to_patches(image, patch_size):
    """
    Divide an image into square patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The edge length of each patch.

    Returns:
        list: PIL.Image.Image patches, ordered left-to-right, top-to-bottom.
    """
    width, height = image.size
    # Crop boxes may extend past the right/bottom edge; PIL pads such crops.
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
| |
|
| |
|
| | from typing import List |
| | import PIL.Image |
| | import torch |
| | import transformers |
# Ovis-specific sentinel ids (negative so they cannot clash with vocab ids).
# NOTE(review): IGNORE_ID / IMAGE_TOKEN_ID / IMAGE_TOKEN duplicate
# IGNORE_INDEX / IMAGE_TOKEN_INDEX / DEFAULT_IMAGE_TOKEN defined earlier in
# this file — confirm both sets are needed.
IGNORE_ID = -100
IMAGE_TOKEN_ID = -200
IMAGE_TOKEN = "<image>"
IMAGE_ATOM_ID = -300  # stands for one image atom (crop) in the token sequence
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]  # open / after-global / col-sep / row-sep / close (see construct_image_placeholders)
| |
|
| |
|
def construct_image_placeholders(grid):
    """
    Build the flat placeholder-id sequence for an image split into *grid* crops.

    The sequence always starts with the global-image atom wrapped by the first
    two indicator ids. When the grid has more than one cell, one atom per crop
    follows, with a column separator between crops in a row, a row separator
    between rows, and a closing indicator at the end.
    """
    placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
    rows, cols = grid
    if rows * cols > 1:
        for row in range(rows):
            if row > 0:
                placeholders.append(IMAGE_INDICATOR_IDS[3])
            for col in range(cols):
                if col > 0:
                    placeholders.append(IMAGE_INDICATOR_IDS[2])
                placeholders.append(IMAGE_ATOM_ID)
        placeholders.append(IMAGE_INDICATOR_IDS[4])
    return placeholders
| |
|
| |
|
def preprocess_image_ovis(image: PIL.Image.Image, image_processor, crop_size, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
    """
    Preprocess an image for Ovis-style multi-crop visual input.

    Chooses a (rows, cols) grid of at most ``max_partition`` cells that best
    covers the image at the square side length ``crop_size``, crops the image
    accordingly (prepending the full image as a global view when there is more
    than one crop), and letterboxes every crop onto a black square tensor.

    Args:
        image (PIL.Image.Image): Source image.
        image_processor: Processor exposing ``preprocess``.
        crop_size (int): Square side length used for every processed crop.
        max_partition (int): Maximum number of grid cells (rows * cols).
        covering_threshold (float): Minimum covering ratio for a grid to be
            considered "good"; below it, the best-covering grid is used.
        convert_to_rgb (bool): Convert non-RGB inputs to RGB first.

    Returns:
        tuple: (pixel_values, image_placeholders) — a list of 1x3xSxS tensors
        and the matching placeholder token-id sequence for the chosen grid.
    """
    def _preprocess(img: PIL.Image.Image, side):
        # Resize so the longer edge equals `side`, preserving aspect ratio.
        w, h = img.size
        if w == h:
            new_width = new_height = side
        elif w > h:
            new_width = side
            new_height = int(h / w * new_width)
        else:
            new_height = side
            new_width = int(w / h * new_height)
        new_size = dict(height=new_height, width=new_width)
        pixel_values = image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

        # Center the processed image on a black `side` x `side` canvas.
        square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
        new_height, new_width = pixel_values.shape[2:]
        if new_height == new_width:
            square_values[:, :, :, :] = pixel_values
        elif new_height > new_width:
            from_index = (side - new_width) // 2
            square_values[:, :, :, from_index:from_index + new_width] = pixel_values
        else:
            from_index = (side - new_height) // 2
            square_values[:, :, from_index:from_index + new_height, :] = pixel_values

        return square_values

    def _partition(img, grid):
        # Split the image into grid[0] x grid[1] boxes; the last row/column
        # absorbs any integer-division remainder so the boxes tile the image.
        w, h = img.size
        row_height = h // grid[0]
        col_width = w // grid[1]

        partition = []
        for row in range(grid[0]):
            for col in range(grid[1]):
                left = col * col_width
                upper = row * row_height
                right = w if col == grid[1] - 1 else (col + 1) * col_width
                lower = h if row == grid[0] - 1 else (row + 1) * row_height
                partition.append((left, upper, right, lower))

        return partition

    def _covering_area(left, upper, right, lower, side):
        # Area of the cell after scaling its longer edge down to `side`;
        # cells already within `side` are counted at full size.
        w = right - left
        h = lower - upper
        w, h = max(w, h), min(w, h)
        if w > side:
            h = h / w * side
            w = side
        return w * h

    def _get_best_grid(img, side):
        # Pick the grid whose cells, rendered at `side`, cover the largest
        # fraction of the image area.
        img_area = img.size[0] * img.size[1]

        candidate_grids = []
        for i in range(1, max_partition + 1):
            for j in range(1, max_partition + 1):
                if i * j <= max_partition:
                    candidate_grids.append((i, j))

        all_grids = []
        good_grids = []
        for grid in candidate_grids:
            partition = _partition(img, grid)
            covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
            assert covering_ratio <= 1.0
            all_grids.append((grid, covering_ratio))
            if covering_ratio > covering_threshold:
                good_grids.append((grid, covering_ratio))

        if len(good_grids) > 0:
            # Among grids above the threshold: fewest cells first, then higher
            # covering ratio as the tiebreaker.
            return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
        else:
            # No grid is good enough: take the best covering ratio, breaking
            # ties with fewer cells.
            return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

    if convert_to_rgb and image.mode != 'RGB':
        image = image.convert('RGB')

    # NOTE(review): sides[0] == sides[1] by construction, so this check can
    # never fire — looks like a leftover from a get_image_size() variant.
    sides = [crop_size, crop_size]
    if sides[0] != sides[1]:
        raise ValueError('get_image_size() returns non-square size')
    side = sides[0]
    grid = _get_best_grid(image, side)
    partition = _partition(image, grid)
    crops = [image.crop(p) for p in partition]
    if len(crops) > 1:
        # More than one crop: prepend the full image as a global view.
        crops.insert(0, image)
    pixel_values = [_preprocess(crop, side) for crop in crops]
    image_placeholders = construct_image_placeholders(grid)
    return pixel_values, image_placeholders
| |
|
| |
|
| |
|
def ovis_template_process(data_dict):
    """
    Expand each image placeholder token in ``input_ids`` into the per-image
    placeholder-id sequence carried alongside its pixel values.

    Args:
        data_dict (dict): Contains 'images' (list of
            (pixel_values, placeholder_ids) pairs), 'input_ids' (1-D tensor
            with one IMAGE_TOKEN_ID entry per image), and 'labels' (same
            length as 'input_ids').

    Returns:
        dict: The same dict, with 'images' reduced to pixel values only and
        'input_ids'/'labels' rebuilt with the placeholders spliced in
        (placeholder positions get label -100, i.e. masked from the loss).
    """
    images = data_dict['images']
    input_ids = data_dict['input_ids']
    labels = data_dict['labels']
    placeholders = [entry[1] for entry in images]

    # Sanity check: one placeholder sequence per image token in the prompt.
    image_positions = torch.nonzero(input_ids == IMAGE_TOKEN_ID).squeeze(1)
    assert len(placeholders) == len(image_positions)

    expanded_ids = []
    expanded_labels = []
    consumed = 0
    for pos, token in enumerate(input_ids):
        if token == IMAGE_TOKEN_ID:
            # Splice in this image's placeholder ids, masked out of the loss.
            for placeholder_id in placeholders[consumed]:
                expanded_ids.append(placeholder_id)
                expanded_labels.append(-100)
            consumed += 1
        else:
            expanded_ids.append(input_ids[pos])
            expanded_labels.append(labels[pos])

    assert len(expanded_ids) == len(expanded_labels)
    assert len(placeholders) == consumed

    data_dict['images'] = [entry[0] for entry in images]
    data_dict['input_ids'] = torch.tensor(expanded_ids)
    data_dict['labels'] = torch.tensor(expanded_labels)
    return data_dict
| |
|
| |
|
def pad_truncate_sequence(multimodal_max_length, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
    """
    Pad a batch of 1-D sequences to a common length, then truncate the
    sequence dimension to ``multimodal_max_length``.

    Args:
        multimodal_max_length: Maximum sequence length to keep.
        sequences: Per-sample 1-D tensors of varying length.
        batch_first: Forwarded to ``pad_sequence`` on the right-padding path;
            the left-padding path always produces a batch-first tensor.
        padding_value: Fill value for padded positions.
        left_padding: Pad (and truncate) on the left instead of the right.

    Returns:
        torch.Tensor: The padded batch, truncated along the sequence dim.
    """
    if not left_padding:
        # Right padding: keep the FIRST `multimodal_max_length` tokens.
        padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
        return padded[:, :multimodal_max_length]
    # Left padding: flip each sequence, right-pad, flip back so the padding
    # sits on the left, then keep the LAST `multimodal_max_length` tokens.
    # Bug fix: the previous slice `[:, multimodal_max_length:]` kept the tail
    # BEYOND the limit and returned an empty tensor whenever the padded batch
    # was not longer than the limit.
    padded = torch.nn.utils.rnn.pad_sequence(
        [seq.flip(dims=[0]) for seq in sequences], batch_first=True, padding_value=padding_value
    ).flip(dims=[1])
    return padded[:, -multimodal_max_length:]