# Valley2.5 / utils.py
# Last commit by Hyggge (64c250f): "feat: modify file type of *.py, *.txt, etc. to change storage method"
from PIL import Image
from io import BytesIO
import base64
import math
import ast
import re
import torch
from transformers import StoppingCriteria
# Sentinel ids and special-token strings shared by the preprocessing code.
# The negative ids are presumably chosen so they never coincide with a real
# tokenizer id.
IGNORE_INDEX = -100  # presumably the label value the loss ignores; confirm against the trainer
IMAGE_TOKEN_INDEX = -200
GANDALF_TOKEN_INDEX = -300
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# NOTE(review): BOS uses the same string as EOS. Many LLaMA-style tokenizers use
# "<s>" for BOS. Left unchanged because tokenizers/checkpoints may depend on it.
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
DEFAULT_VI_START_TOKEN = "<vi_start>"
DEFAULT_VI_END_TOKEN = "<vi_end>"
DEFAULT_EOC_TOKEN = "<eoc>"
COR_START_TOKEN = "<cor>"
# Raw string so the backslash is explicit. The value is still exactly the six
# characters '<', '\', 'c', 'o', 'r', '>'. A non-raw literal relies on "\c"
# being an unrecognised escape, which emits SyntaxWarning on Python 3.12+.
# NOTE(review): "</cor>" may have been intended; changing it would alter the token text.
COR_END_TOKEN = r"<\cor>"
SEQ_MAX_LEN = 50000
# Raw PNG bytes. The IHDR chunk declares a 3x3, 8-bit RGB image; the name
# suggests it is all black (presumably a placeholder image).
BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str | list | tuple): Either
            - a range string such as "(1x1),...,(3x3)". Only the first and last
              "(AxB)" entries are read, and every (i, j) with
              A_first <= i <= A_last and B_first <= j <= B_last becomes the
              candidate resolution (i * patch_size, j * patch_size); or
            - a Python-literal string, list or tuple of (width, height) pairs.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).

    Raises:
        ValueError: If a range string contains no "(AxB)" entries.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        if not matches:
            raise ValueError(f"grid_pinpoints {grid_pinpoints!r} contains no '(AxB)' entries")
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Expand the inclusive range of grid multipliers straight into pixel resolutions.
        grid_pinpoints = [
            [i * patch_size, j * patch_size]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
    # Accept any list/tuple of pairs as-is; only strings need literal parsing.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that best fits an image.

    Each candidate is scored by how many original pixels survive an
    aspect-preserving resize into it, capped at the original pixel count.
    The highest score wins. Ties go to the candidate that wastes the fewest
    pixels (candidate area minus score). Among exact ties the earliest
    candidate is kept.

    Args:
        original_size (tuple): (width, height) of the original image.
        possible_resolutions (iterable): Candidate (width, height) pairs.

    Returns:
        tuple | None: The chosen (width, height) as a tuple, or None when there
        are no candidates.
    """
    orig_w, orig_h = original_size
    orig_area = orig_w * orig_h

    best = None
    best_effective = 0
    best_wasted = float("inf")

    for cand_w, cand_h in possible_resolutions:
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        fitted_area = int(orig_w * ratio) * int(orig_h * ratio)
        effective = min(fitted_area, orig_area)
        wasted = cand_w * cand_h - effective

        strictly_better = effective > best_effective
        same_but_leaner = effective == best_effective and wasted < best_wasted
        if strictly_better or same_but_leaner:
            best, best_effective, best_wasted = (cand_w, cand_h), effective, wasted

    return best
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The image is assumed to have been scaled to fit the tensor while keeping
    its aspect ratio, then padded symmetrically along one axis. This crops
    that padding back off.

    Args:
        tensor (torch.Tensor): The padded image tensor, in CxHxW format
            (exactly three dimensions).
        original_size (tuple): The original size of the image as
            (width, height), i.e. PIL ``Image.size`` order.

    Returns:
        torch.Tensor: The unpadded image tensor, as a slice of ``tensor``.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # The original is relatively wider: it was fitted to the width, so
        # padding was added at the top and bottom.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        # Floor division: with an odd total padding, one padded row remains.
        unpadded_tensor = tensor[:, padding: current_height - padding, :]
    else:
        # The original is relatively taller (or equal): it was fitted to the
        # height, so padding was added on the left and right.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding: current_width - padding]
    return unpadded_tensor
def _anyres_tile_side(size):
    """Return the square tile side length encoded in an image processor's ``size``."""
    if isinstance(size, int):
        return size
    if isinstance(size, dict):
        # HF-style processors use either {"height", "width"} or {"shortest_edge"}.
        return size["height"] if "height" in size else size["shortest_edge"]
    return min(size)


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best-fitting candidate resolution,
    then cut into square tiles whose side is taken from ``processor.size``.
    A resized (not padded) copy of the whole image is prepended as the first
    tile, and every tile is run through ``processor.preprocess``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor. ``processor.size`` may be a dict with a
            "height" or "shortest_edge" key, a single int, or a sequence of ints.
            The same side length is used for grid expansion, tiling and the
            whole-image resize.
        grid_pinpoints (str | list | tuple): Either a range string such as
            "(1x1),...,(3x3)", whose multipliers are expanded and scaled by the
            tile side, or a Python-literal string, list or tuple of
            (width, height) pairs.

    Returns:
        list[torch.Tensor]: One preprocessed tensor per tile, whole-image tile first.

    Raises:
        ValueError: If a range string contains no "(AxB)" entries.
    """
    side = _anyres_tile_side(processor.size)

    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert side in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        if not matches:
            raise ValueError(f"grid_pinpoints {grid_pinpoints!r} contains no '(AxB)' entries")
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Expand the inclusive range of grid multipliers straight into pixel resolutions.
        grid_pinpoints = [
            [i * side, j * side]
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)
    patches = divide_to_patches(image_padded, side)

    # NOTE: the whole-image view is a plain (aspect-distorting) resize rather
    # than a pad to square. An earlier FIXME flagged this as possibly
    # unintended; it is kept for consistency with earlier behaviour.
    image_original_resize = image.resize((side, side))

    image_patches = [image_original_resize] + patches
    return [
        processor.preprocess(tile, return_tensors="pt")["pixel_values"][0]
        for tile in image_patches
    ]
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target canvas, keeping its aspect ratio, and center
    it on a black RGB background.

    The limiting side is stretched to exactly fill the target. The other side
    is scaled by the same factor, rounded up and clamped to the target.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): Target size as (width, height).

    Returns:
        PIL.Image.Image: A new RGB image of exactly ``target_resolution``.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution
    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w >= ratio_h:
        # Height is the limiting side: fill it, scale the width to match.
        fit_w = min(math.ceil(src_w * ratio_h), dst_w)
        fit_h = dst_h
    else:
        # Width is the limiting side: fill it, scale the height to match.
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * ratio_w), dst_h)

    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(image.resize((fit_w, fit_h)), offset)
    return canvas
def divide_to_patches(image, patch_size):
    """
    Tile an image into square crops of side ``patch_size``.

    Crops are produced row by row (top to bottom), left to right within each
    row. When the image size is not a multiple of ``patch_size``, the last
    crop in a row or column extends past the image edge. What fills that area
    is determined by ``image.crop``.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each square crop.

    Returns:
        list: The cropped images, in row-major order.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
from typing import List
import PIL.Image
import torch
import transformers
# Ovis-style multimodal ids. All are negative, presumably so they never clash
# with real tokenizer ids.
IGNORE_ID = -100  # label value for unsupervised positions (presumably the loss ignore index)
IMAGE_TOKEN_ID = -200  # marks an image position in input_ids before placeholder expansion
IMAGE_TOKEN = "<image>"
# NOTE(review): same value as GANDALF_TOKEN_INDEX; confirm the two id spaces never meet.
IMAGE_ATOM_ID = -300  # one slot per visual crop
IMAGE_INDICATOR_IDS = [
    -301,  # opens the image, before the whole-image atom
    -302,  # follows the whole-image atom
    -303,  # separates crops within one grid row
    -304,  # separates grid rows
    -305,  # closes the image
]
def construct_image_placeholders(grid):
    """
    Build the placeholder id sequence for one image split into a grid of crops.

    Layout:
        [IND0, ATOM, IND1]  -- the whole-image view, always present
        then, only when the grid has more than one cell, one ATOM per crop in
        row-major order, with IND2 between crops of the same row and IND3
        between rows;
        and finally IND4, always present.
    where INDk is ``IMAGE_INDICATOR_IDS[k]`` and ATOM is ``IMAGE_ATOM_ID``.

    Args:
        grid: (rows, cols) of the crop grid.

    Returns:
        list[int]: The placeholder ids.
    """
    rows, cols = grid[0], grid[1]
    ids = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
    if rows * cols > 1:
        for r in range(rows):
            if r:
                ids.append(IMAGE_INDICATOR_IDS[3])  # between rows
            for c in range(cols):
                if c:
                    ids.append(IMAGE_INDICATOR_IDS[2])  # between crops in a row
                ids.append(IMAGE_ATOM_ID)
    ids.append(IMAGE_INDICATOR_IDS[4])
    return ids
def preprocess_image_ovis(image: PIL.Image.Image, image_processor, crop_size, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
    """
    Split an image into an Ovis-style grid of square crops and preprocess each crop.

    The grid is chosen as follows:
        - Candidates are every (rows, cols) with rows * cols <= ``max_partition``.
        - Each candidate is scored by its covering ratio: the fraction of the
          image area that survives when every crop is shrunk to fit inside a
          ``crop_size`` square.
        - If any candidates score above ``covering_threshold``, the one with
          the fewest crops wins, with ties broken by the higher ratio.
        - Otherwise the highest ratio wins, with ties broken by fewer crops.
    When the chosen grid has more than one cell, the whole image is prepended
    as an extra crop.

    Each crop is resized by ``image_processor`` so its longer edge equals
    ``crop_size`` (aspect ratio kept). It is then zero-padded, centered, into
    a ``crop_size`` x ``crop_size`` square.

    Args:
        image (PIL.Image.Image): The input image.
        image_processor: Object whose ``preprocess(img, size={"height", "width"},
            return_tensors='pt')`` returns a mapping with 'pixel_values'.
        crop_size (int): Side length of the square crops.
        max_partition (int): Maximum number of grid cells.
        covering_threshold (float): Covering ratio a grid must exceed to count as "good".
        convert_to_rgb (bool): Convert non-RGB images to RGB first.

    Returns:
        tuple: ``(pixel_values, image_placeholders)``.
            - ``pixel_values`` is a list with one tensor per crop, each of shape
              (1, 3, crop_size, crop_size); the caller concatenates them.
            - ``image_placeholders`` is ``construct_image_placeholders(grid)``.
    """

    def _preprocess(img: PIL.Image.Image, side):
        # Resize so the longer edge equals `side`, keeping the aspect ratio.
        w, h = img.size
        if w == h:
            new_width = new_height = side
        elif w > h:
            new_width = side
            # max(1, ...) avoids a zero-sized target for extremely thin crops.
            new_height = max(1, int(h / w * new_width))
        else:
            new_height = side
            new_width = max(1, int(w / h * new_height))
        new_size = dict(height=new_height, width=new_width)
        pixel_values = image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

        # Zero-pad, centered, into a side x side square using the size the
        # processor actually produced.
        square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
        new_height, new_width = pixel_values.shape[2:]
        if new_height == new_width:
            square_values[:, :, :, :] = pixel_values
        elif new_height > new_width:
            from_index = (side - new_width) // 2
            square_values[:, :, :, from_index:from_index + new_width] = pixel_values
        else:
            from_index = (side - new_height) // 2
            square_values[:, :, from_index:from_index + new_height, :] = pixel_values
        return square_values

    def _partition(img, grid):
        # Split into grid[0] rows x grid[1] cols of (left, upper, right, lower)
        # boxes. The last row/column absorbs the integer-division remainder.
        w, h = img.size
        row_height = h // grid[0]
        col_width = w // grid[1]
        boxes = []
        for row in range(grid[0]):
            for col in range(grid[1]):
                left = col * col_width
                upper = row * row_height
                right = w if col == grid[1] - 1 else (col + 1) * col_width
                lower = h if row == grid[0] - 1 else (row + 1) * row_height
                boxes.append((left, upper, right, lower))
        return boxes

    def _covering_area(left, upper, right, lower, side):
        # Area of the box after shrinking it (aspect kept) so its longer edge
        # is at most `side`.
        w = right - left
        h = lower - upper
        w, h = max(w, h), min(w, h)
        if w > side:
            h = h / w * side
            w = side
        return w * h

    def _get_best_grid(img, side):
        img_area = img.size[0] * img.size[1]
        candidate_grids = [
            (rows, cols)
            for rows in range(1, max_partition + 1)
            for cols in range(1, max_partition + 1)
            if rows * cols <= max_partition
        ]
        all_grids = []
        for grid in candidate_grids:
            covering_ratio = sum(_covering_area(*box, side) for box in _partition(img, grid)) / img_area
            assert covering_ratio <= 1.0  # internal invariant: shrinking never adds area
            all_grids.append((grid, covering_ratio))
        good_grids = [entry for entry in all_grids if entry[1] > covering_threshold]
        # min() keeps the first of equal keys, matching a stable sort's first element.
        if good_grids:
            # Fewest sub-images; break ties by higher covering ratio.
            return min(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0]
        # Highest covering ratio; break ties by fewer sub-images.
        return min(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0]

    if convert_to_rgb and image.mode != 'RGB':
        image = image.convert('RGB')

    side = crop_size
    grid = _get_best_grid(image, side)
    crops = [image.crop(box) for box in _partition(image, grid)]
    if len(crops) > 1:
        crops.insert(0, image)
    pixel_values = [_preprocess(crop, side) for crop in crops]  # concatenated by the caller
    image_placeholders = construct_image_placeholders(grid)
    return pixel_values, image_placeholders
def ovis_template_process(data_dict):
    """
    Expand each IMAGE_TOKEN_ID in ``input_ids`` into that image's placeholder ids.

    ``data_dict['images']`` holds one pair per image, indexed as
    ``img[0]`` (pixel values) and ``img[1]`` (placeholder ids). These are
    presumably the ``(pixel_values, image_placeholders)`` tuples from
    ``preprocess_image_ovis``. The pairs must be in the same order as the
    IMAGE_TOKEN_ID occurrences in ``data_dict['input_ids']`` (a 1-D tensor or
    sequence).

    Each occurrence is replaced by the image's placeholder ids, all labelled
    IGNORE_ID. Other tokens keep their original label.

    The dict is modified in place and also returned:
        - 'images' becomes the list of pixel values only;
        - 'input_ids' and 'labels' become new 1-D tensors.

    Raises:
        ValueError: If the number of images differs from the number of
            IMAGE_TOKEN_ID tokens.
    """
    images = data_dict['images']
    input_ids = data_dict['input_ids']
    labels = data_dict['labels']

    placeholders = [img[1] for img in images]
    # Work on plain Python values: element-wise comparisons on 0-d tensors are slow.
    id_list = input_ids.tolist() if torch.is_tensor(input_ids) else list(input_ids)
    label_list = labels.tolist() if torch.is_tensor(labels) else list(labels)

    num_image_tokens = sum(1 for token in id_list if token == IMAGE_TOKEN_ID)
    if num_image_tokens != len(placeholders):
        raise ValueError(
            f"found {num_image_tokens} image tokens but {len(placeholders)} images"
        )

    new_input_ids = []
    new_labels = []
    image_idx = 0
    for pos, token in enumerate(id_list):
        if token == IMAGE_TOKEN_ID:
            placeholder = list(placeholders[image_idx])
            image_idx += 1
            new_input_ids.extend(placeholder)
            new_labels.extend([IGNORE_ID] * len(placeholder))
        else:
            new_input_ids.append(token)
            new_labels.append(label_list[pos])

    data_dict['images'] = [img[0] for img in images]
    data_dict['input_ids'] = torch.tensor(new_input_ids)
    data_dict['labels'] = torch.tensor(new_labels)
    return data_dict
def pad_truncate_sequence(multimodal_max_length, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
    """
    Pad variable-length sequences into one tensor, then truncate along the time axis.

    Args:
        multimodal_max_length (int | None): Maximum number of time steps to
            keep. ``None`` disables truncation.
        sequences: Tensors of shape (T_i, *) with matching trailing dimensions.
        batch_first: If True the result is (B, T, *); otherwise (T, B, *).
            Honoured for both padding directions.
        padding_value: Fill value for padded positions.
        left_padding: If True, sequences are right-aligned (padding on the
            left) and truncation keeps the LAST ``multimodal_max_length`` steps.
            Otherwise padding is on the right and truncation keeps the first steps.

    Returns:
        torch.Tensor: The padded and truncated batch.
    """
    if left_padding:
        # Right-align by reversing each sequence, right-padding, and reversing back.
        padded = torch.nn.utils.rnn.pad_sequence(
            [seq.flip(dims=[0]) for seq in sequences], batch_first=True, padding_value=padding_value
        ).flip(dims=[1])
        if not batch_first:
            padded = padded.transpose(0, 1)
    else:
        padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)

    if multimodal_max_length is None:
        return padded

    if left_padding:
        # Keep the trailing steps: the real tokens sit at the right edge.
        time_len = padded.size(1 if batch_first else 0)
        keep = slice(max(time_len - multimodal_max_length, 0), None)
    else:
        keep = slice(None, multimodal_max_length)
    # Truncate the time axis, never the batch axis.
    return padded[:, keep] if batch_first else padded[keep]