| import torch |
| import ast |
| import math |
| from PIL import Image |
|
|
|
|
| def has_fn(model, fn_name): |
| """Check if model has a function fn_name""" |
| return callable(getattr(model, fn_name, None)) |
|
|
def exists(val):
    """Check that val is not None"""
    return val is not None
|
|
| def num_params(module, filter_to_trainable=False): |
| """Returns the number of parameters in the module, or optionally only the trainable parameters""" |
| if filter_to_trainable: |
| return sum(p.numel() for p in module.parameters() if p.requires_grad) |
| else: |
| return sum(p.numel() for p in module.parameters()) |
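
# Illustrative usage (a hypothetical toy module):
# >>> layer = torch.nn.Linear(10, 10)
# >>> num_params(layer)  # 100 weights + 10 biases
# 110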
|
|
| def hasattr_recursive(obj, att): |
| """ |
| Check if obj has nested attribute |
| Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c') |
| """ |
| if att == "": |
| return True |
| i = att.find(".") |
| if i < 0: |
| return hasattr(obj, att) |
| else: |
| try: |
| return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) |
        except AttributeError:
| return False |
|
|
| def getattr_recursive(obj, att): |
| """ |
| Return nested attribute of obj |
| Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c |
| """ |
| if att == "": |
| return obj |
| i = att.find(".") |
| if i < 0: |
| return getattr(obj, att) |
| else: |
| return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) |
|
|
|
|
| def setattr_recursive(obj, att, val): |
| """ |
| Set nested attribute of obj |
| Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val |
| """ |
| if "." in att: |
| obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) |
| setattr(obj, att.split(".")[-1], val) |
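
# Illustrative example of the recursive attribute helpers (toy module tree;
# nn.Sequential exposes its children under string keys like "0"):
# >>> model = torch.nn.Sequential(torch.nn.Linear(4, 4))
# >>> hasattr_recursive(model, "0.weight")
# True
# >>> getattr_recursive(model, "0.weight").shape
# torch.Size([4, 4])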
| |
| |
| def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"): |
| """ |
| Stack a list of tensors with padding on one side |
| Args: |
| list_of_tensors (list[torch.Tensor]): List of tensors to stack |
| padding_value (int, optional): Value to pad with. Defaults to 0. |
| padding_side (str, optional): Side to pad on. Defaults to "right". |
| Returns: |
| torch.Tensor: Stacked tensors |
| """ |
| max_tokens = max(tensor.size(0) for tensor in list_of_tensors) |
| padded_tensors = [] |
| for tensor in list_of_tensors: |
| num_tokens = tensor.size(0) |
        # Build padding that matches the tensor's trailing dimensions, so both
        # 1-D and N-D tensors are handled uniformly.
        padding = torch.full(
            (max_tokens - num_tokens, *tensor.shape[1:]),
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
| padded_tensor = ( |
| torch.cat((tensor, padding), dim=0) |
| if padding_side == "right" |
| else torch.cat((padding, tensor), dim=0) |
| ) |
| padded_tensors.append(padded_tensor) |
| return torch.stack(padded_tensors) |
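
# Illustrative usage:
# >>> stack_with_padding([torch.tensor([1, 2]), torch.tensor([3])], padding_value=-1)
# tensor([[ 1,  2],
#         [ 3, -1]])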
| |
| |
| def check_embedding_fns(lang_model): |
| """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model""" |
| if not has_fn(lang_model, "get_input_embeddings"): |
| if hasattr_recursive(lang_model, "transformer.wte"): |
| lang_model.get_input_embeddings = lambda: lang_model.transformer.wte |
| elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): |
            lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
| else: |
| raise ValueError( |
| "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py." |
| ) |
|
|
| if not has_fn(lang_model, "set_input_embeddings"): |
| if hasattr_recursive(lang_model, "transformer.wte"): |
| lang_model.set_input_embeddings = lambda x: setattr_recursive( |
| lang_model, "transformer.wte", x |
| ) |
| elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): |
| lang_model.set_input_embeddings = lambda x: setattr_recursive( |
| lang_model, "model.decoder.embed_tokens", x |
| ) |
| else: |
| raise ValueError( |
| "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py." |
| ) |
|
|
| if not has_fn(lang_model, "get_output_embeddings"): |
| if hasattr_recursive(lang_model, "lm_head"): |
| lang_model.get_output_embeddings = lambda: lang_model.lm_head |
| else: |
| raise ValueError( |
| "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." |
| ) |
|
|
| if not has_fn(lang_model, "set_output_embeddings"): |
| if hasattr_recursive(lang_model, "lm_head"): |
| lang_model.set_output_embeddings = lambda x: setattr_recursive( |
| lang_model, "lm_head", x |
| ) |
| else: |
| raise ValueError( |
| "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." |
| ) |
|
|
|
|
|
def unpad_image(tensor, original_size, keep_original_shape=False):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image (width, height).
        keep_original_shape (bool, optional): If True, keep the padded tensor and
            return an attention mask that zeroes out the padded region. Defaults to False.

    Returns:
        tuple: The image tensor (unpadded unless keep_original_shape is True) and the
            attention mask (None when keep_original_shape is False).
    """
| original_width, original_height = original_size |
| current_height, current_width = tensor.shape[1:] |
|
|
| original_aspect_ratio = original_width / original_height |
| current_aspect_ratio = current_width / current_height |
|
|
| if original_aspect_ratio > current_aspect_ratio: |
| scale_factor = current_width / original_width |
| new_height = int(original_height * scale_factor) |
| padding = (current_height - new_height) // 2 |
| if keep_original_shape: |
| attention_mask = torch.ones((current_height, current_width), device=tensor.device) |
| attention_mask[:padding, :] = 0 |
| attention_mask[current_height - padding:, :] = 0 |
| return tensor, attention_mask |
| else: |
| unpadded_tensor = tensor[:, padding:current_height - padding, :] |
| return unpadded_tensor, None |
| else: |
| scale_factor = current_height / original_height |
| new_width = int(original_width * scale_factor) |
| padding = (current_width - new_width) // 2 |
| if keep_original_shape: |
| attention_mask = torch.ones((current_height, current_width), device=tensor.device) |
| attention_mask[:, :padding] = 0 |
| attention_mask[:, current_width - padding:] = 0 |
| return tensor, attention_mask |
| else: |
| unpadded_tensor = tensor[:, :, padding:current_width - padding] |
| return unpadded_tensor, None |
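
# Worked example: a 1000x500 image resized-and-padded into a 224x224 tensor has
# 56 rows of padding above and below, so unpadding recovers a 3x112x224 crop:
# >>> t = torch.zeros(3, 224, 224)
# >>> unpad_image(t, (1000, 500))[0].shape
# torch.Size([3, 112, 224])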
|
|
|
|
| def select_best_resolution(original_size, possible_resolutions): |
| """ |
| Selects the best resolution from a list of possible resolutions based on the original size. |
| |
| Args: |
| original_size (tuple): The original size of the image in the format (width, height). |
| possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. |
| |
| Returns: |
| tuple: The best fit resolution in the format (width, height). |
| """ |
| original_width, original_height = original_size |
| best_fit = None |
| max_effective_resolution = 0 |
| min_wasted_resolution = float('inf') |
|
|
| for width, height in possible_resolutions: |
| scale = min(width / original_width, height / original_height) |
| downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) |
| effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) |
| wasted_resolution = (width * height) - effective_resolution |
|
|
| if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): |
| max_effective_resolution = effective_resolution |
| min_wasted_resolution = wasted_resolution |
| best_fit = (width, height) |
|
|
| return best_fit |
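
# Worked example: for a 1000x500 image, (448, 224) preserves the full downscaled
# area (448*224 px) with zero wasted resolution, so it beats the other candidates:
# >>> select_best_resolution((1000, 500), [(224, 224), (448, 224), (448, 448)])
# (448, 224)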
|
|
|
|
| def resize_and_pad_image(image, target_resolution): |
| """ |
| Resize and pad an image to a target resolution while maintaining aspect ratio. |
| |
| Args: |
| image (PIL.Image.Image): The input image. |
| target_resolution (tuple): The target resolution (width, height) of the image. |
| |
| Returns: |
| PIL.Image.Image: The resized and padded image. |
| """ |
| original_width, original_height = image.size |
| target_width, target_height = target_resolution |
|
|
| scale_w = target_width / original_width |
| scale_h = target_height / original_height |
|
|
| if scale_w < scale_h: |
| new_width = target_width |
| new_height = min(math.ceil(original_height * scale_w), target_height) |
| else: |
| new_height = target_height |
| new_width = min(math.ceil(original_width * scale_h), target_width) |
|
|
| |
| resized_image = image.resize((new_width, new_height)) |
|
|
| new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) |
| paste_x = (target_width - new_width) // 2 |
| paste_y = (target_height - new_height) // 2 |
| new_image.paste(resized_image, (paste_x, paste_y)) |
|
|
| return new_image |
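
# Illustrative usage: a 1000x500 image targeted at (224, 224) is resized to
# 224x112 and centered on a black canvas:
# >>> padded = resize_and_pad_image(Image.new("RGB", (1000, 500)), (224, 224))
# >>> padded.size
# (224, 224)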
|
|
|
|
| def divide_to_patches(image, patch_size): |
| """ |
| Divides an image into patches of a specified size. |
| |
| Args: |
| image (PIL.Image.Image): The input image. |
| patch_size (int): The size of each patch. |
| |
| Returns: |
| list: A list of PIL.Image.Image objects representing the patches. |
| """ |
| patches = [] |
| width, height = image.size |
| for i in range(0, height, patch_size): |
| for j in range(0, width, patch_size): |
| box = (j, i, j + patch_size, i + patch_size) |
| patch = image.crop(box) |
| patches.append(patch) |
|
|
| return patches |
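
# Illustrative usage: a 448x448 image with patch_size=224 yields a 2x2 grid:
# >>> patches = divide_to_patches(Image.new("RGB", (448, 448)), 224)
# >>> len(patches)
# 4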
|
|
|
|
| def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): |
| """ |
| Calculate the shape of the image patch grid after the preprocessing for images of any resolution. |
| |
| Args: |
| image_size (tuple): The size of the input image in the format (width, height). |
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
| patch_size (int): The size of each image patch. |
| |
| Returns: |
| tuple: The shape of the image patch grid in the format (width, height). |
| """ |
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
| width, height = select_best_resolution(image_size, possible_resolutions) |
| return width // patch_size, height // patch_size |
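
# Worked example: a 1000x500 image maps to the (448, 224) resolution, giving a
# 2x1 grid of 224-pixel patches:
# >>> get_anyres_image_grid_shape((1000, 500), "[[448, 224], [224, 448]]", 224)
# (2, 1)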
|
|
|
|
| def process_anyres_image(image, processor, grid_pinpoints): |
| """ |
| Process an image with variable resolutions. |
| |
| Args: |
| image (PIL.Image.Image): The input image to be processed. |
| processor: The image processor object. |
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
| |
| Returns: |
| torch.Tensor: A tensor containing the processed image patches. |
| """ |
| |
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
| best_resolution = select_best_resolution(image.size, possible_resolutions) |
| image_padded = resize_and_pad_image(image, best_resolution) |
|
|
| processor_size = processor.transforms[0].size |
| patches = divide_to_patches(image_padded, processor_size[0]) |
|
|
| image_original_resize = image.resize((processor_size[0], processor_size[0])) |
|
|
| image_patches = [image_original_resize] + patches |
| image_patches = [processor(image_patch) |
| for image_patch in image_patches] |
| return torch.stack(image_patches, dim=0) |
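
# Illustrative usage (assumes `processor` is a torchvision Compose whose first
# transform exposes a `.size` attribute, e.g. transforms.Resize((224, 224)),
# and that it returns 3x224x224 tensors):
# >>> img = Image.new("RGB", (1000, 500))
# >>> out = process_anyres_image(img, processor, "[[448, 224], [224, 448]]")
# >>> out.shape  # 1 global view + 2 patches
# torch.Size([3, 3, 224, 224])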
|
|
|
|
def expand2square(pil_img, background_color):
    """Pad a PIL image with background_color to a square of side max(width, height), centering the original image."""
| width, height = pil_img.size |
| if width == height: |
| return pil_img |
| elif width > height: |
| result = Image.new(pil_img.mode, (width, width), background_color) |
| result.paste(pil_img, (0, (width - height) // 2)) |
| return result |
| else: |
| result = Image.new(pil_img.mode, (height, height), background_color) |
| result.paste(pil_img, ((height - width) // 2, 0)) |
| return result |
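
# Illustrative usage: a 100x50 image is padded (with black) to a 100x100 square:
# >>> expand2square(Image.new("RGB", (100, 50)), (0, 0, 0)).size
# (100, 100)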
| |
|
|
def process_images(images, image_processor, model_cfg):
    """Preprocess a list of PIL images into tensors according to model_cfg.image_aspect_ratio ('pad', 'anyres', 'anyres-legacy', or default)."""
| image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) |
| new_images = [] |
| if image_aspect_ratio == 'pad': |
| for image in images: |
| image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean)) |
| image = image_processor(image) |
| new_images.append(image) |
| elif image_aspect_ratio in ["anyres", "anyres-legacy"]: |
| base_img_size = image_processor.transforms[0].size[0] |
| for image in images: |
            image = process_anyres_image(
                image,
                image_processor,
                [
                    [base_img_size, base_img_size * 2],
                    [base_img_size * 2, base_img_size],
                    [base_img_size * 2, base_img_size * 2],
                    [base_img_size * 3, base_img_size],
                    [base_img_size, base_img_size * 3],
                ],
            )
            new_images.append(image)
| else: |
| return image_processor(images) |
| if all(x.shape == new_images[0].shape for x in new_images): |
| new_images = torch.stack(new_images, dim=0) |
| return new_images |
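
# Illustrative usage (hypothetical config object; assumes `image_processor` is a
# torchvision Compose starting with a sized transform and ending in Normalize,
# since the 'pad' and anyres branches index into `.transforms`):
# >>> class Cfg:
# ...     image_aspect_ratio = "anyres"
# >>> imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (500, 1000))]
# >>> batch = process_images(imgs, image_processor, Cfg())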
|
|
|
|
|
|