import random
from typing import Optional, Union, Dict, Any, List

from einops import rearrange, repeat
import torch
import math
import PIL.Image
import PIL.ImageSequence
import numpy as np
from PIL import Image

from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers import AutoImageProcessor
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
    ImageInput,
    make_list_of_images,
    valid_images,
    is_torch_tensor,
    is_batched,
    to_numpy_array,
    infer_channel_dimension_format,
    ChannelDimension
)
from torchvision.ops.boxes import box_area
from torchvision.transforms import functional as F
from torchvision.transforms.transforms import InterpolationMode
from torchvision import transforms

def recursive_converter(converter, value):
    if isinstance(value, list):
        new_value = []
        for v in value:
            new_value += [recursive_converter(converter, v)]
        return new_value
    else:
        return converter(value)

def box_iou(boxes1, area1, boxes2, eps=1e-5):
    # Pairwise IoU between xyxy boxes; the area of boxes1 is precomputed and passed in,
    # so repeated calls with the same anchors do not recompute it.
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N, M, 2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2]

    wh = (rb - lt).clamp(min=0)  # [N, M, 2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N, M]

    union = area1[:, None] + area2 - inter

    iou = inter / (union + eps)
    return iou, union

available_anchor_strategy = ['docowl', 'random', 'highest', 'last', 'llava']

# Candidate cut grids; each entry is a (rows, cols) pair.
grid_dict = {
    'grid_33': [
        (1,1),
        (1,2),(2,1),
        (1,3),(3,1),
        (2,2),(1,4),(4,1),
        (1,5),(5,1),
        (1,6),(6,1),(2,3),(3,2),
        (1,7),(7,1),
        (4,2),(2,4),(1,8),(8,1),
        (3,3),(1,9),(9,1)],
    'grid_squ_3x3': [
        (1,1),(2,2),(3,3)
    ],
    'grid_squ_4': [
        (2,2),(1,3),(1,4),(3,1),(4,1)
    ],
    'grid_squ_6': [
        (2,2),(1,3),(1,4),(3,1),(4,1),(2,3),(3,2)
    ],
    'grid_squ_2': [
        (2,1)
    ],
    'grid_squ_9': [
        (1,1),
        (1,2),(2,1),
        (1,3),(3,1),
        (2,2),(1,4),(4,1),
        (1,5),(5,1),
        (1,6),(6,1),(2,3),(3,2),
        (1,7),(7,1),
        (4,2),(2,4),(1,8),(8,1),
        (3,3),(1,9),(9,1)],
}

# Prompt builders that textually describe how an image was cut into sub-images.
cut_prompt_template_dict = {
    'v0': lambda img_token, h, w: ''.join([f"{img_token}" for i in range(h) for j in range(w)]),
    'v1': lambda img_token, h, w: f'Cut to {h} rows {w} columns, ' + ' '.join([f"subimg({i},{j}){img_token}" for i in range(h) for j in range(w)]),
    'v1_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view, ' + ' '.join([f"subimg({i},{j}){img_token}" for i in range(h) for j in range(w)] + [f"global_view{img_token}"]),
    'v2_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view\n' + '\n'.join([' '.join([f"subimg({i},{j}){img_token}" for j in range(w)]) for i in range(h)]) + f"\nglobal_view{img_token}",
}

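# For illustration, cut_prompt_template_dict['v2_global']('<|image|>', 2, 2) produces:
#   Cut to 2 rows 2 columns with a global view
#   subimg(0,0)<|image|> subimg(0,1)<|image|>
#   subimg(1,0)<|image|> subimg(1,1)<|image|>
#   global_view<|image|>
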
def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
    # anchors: [N, 4] boxes in xyxy format (x1, y1, x2, y2)
    # input_image_size: (height, width)
    input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)

    boxes1 = anchors
    boxes2 = input_image_bbox
    boxes3 = anchors.clone()
    # boxes3 keeps each anchor's width but rescales its height to the input aspect ratio,
    # so its IoU against the anchor measures shape similarity independent of resolution.
    boxes3[:, 3] = input_image_size[0] / input_image_size[1] * anchors[:, 2]  # y2

    area1 = anchors_areas

    iou, _ = box_iou(boxes1, area1, boxes2)
    iou = iou.squeeze(1)
    shape_iou, _ = box_iou(boxes1, area1, boxes3)
    shape_iou = shape_iou.diag()
    # Prefer anchors whose shape matches the input; break ties by resolution overlap.
    index = torch.argmax(shape_iou * 100 + iou, dim=0)
    return index

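# Usage sketch for anchor_rank (illustrative values only): build anchors the same way
# AnchorResize does, assuming a 448-px base cell, and rank them for a 700x1000 (h, w) image.
#   _anchors = torch.tensor([[0, 0, cols * 448, rows * 448] for (rows, cols) in grid_dict['grid_squ_6']])
#   _best = anchor_rank(_anchors, box_area(_anchors), (700, 1000))  # index into grid_squ_6
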
def select_best_resolution(anchors, anchors_areas, input_image_size):
    """
    Selects the best-fitting anchor for the input image, following the LLaVA resolution-selection
    heuristic: maximize the effective (non-upscaled) resolution, then minimize wasted area.

    Args:
        anchors: [N, 4] anchor boxes in xyxy format, so (anchor[2], anchor[3]) is (width, height).
        anchors_areas: precomputed anchor areas (unused; kept for interface parity with anchor_rank).
        input_image_size (tuple): input image size in the format (height, width).

    Returns:
        int: index of the best-fitting anchor.
    """
    original_size = (input_image_size[1], input_image_size[0])
    possible_resolutions = [(_[2], _[3]) for _ in anchors]

    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    index = 0
    for i, (width, height) in enumerate(possible_resolutions):
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)
            index = i

    return index

def build_cut_shape_indices(cut_shape):
    # cut_shape: list of (num_rows, num_cols) per image. Each sub-image gets a
    # [num_rows, num_cols, flat_index] row so it can be mapped back to its grid position.
    cut_shape_indices = []
    for shape in cut_shape:
        n = shape[0] * shape[1]
        indices = torch.cat([
            repeat(torch.tensor(shape), 'l -> n l', n=n),
            torch.arange(n).unsqueeze(1)
        ], dim=1)
        assert indices.shape[0] == n
        assert indices.shape[1] == 3

        cut_shape_indices.append(indices)
    cut_shape_indices = torch.cat(cut_shape_indices, dim=0).long()
    return cut_shape_indices

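# For illustration, build_cut_shape_indices([(2, 3)]) returns a (6, 3) LongTensor whose rows are
#   [2, 3, 0], [2, 3, 1], [2, 3, 2], [2, 3, 3], [2, 3, 4], [2, 3, 5]
# i.e. (num_rows, num_cols, flat index) for each of the 2x3 sub-images.
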
class AnchorResize(torch.nn.Module):

    def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None, anchor_strategy='docowl'):
        super().__init__()
        self.image_size = image_size
        # xyxy anchors: each (rows, cols) grid becomes the box (0, 0, cols * W, rows * H).
        self.anchors = torch.tensor(
            [[0, 0, _[1] * image_size[1], _[0] * image_size[0]]
             for _ in anchors], requires_grad=False
        )

        self.anchor_areas = box_area(self.anchors)

        self.interpolation = interpolation
        self.antialias = antialias
        self.anchor_strategy = anchor_strategy
        assert self.anchor_strategy in available_anchor_strategy

    def resize_global(self, img):
        return F.resize(img, self.image_size, self.interpolation, max_size=None, antialias=self.antialias)

    def forward(self, img, skip_resize=False):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.
            skip_resize (bool): If True, only select an anchor and return its index.

        Returns:
            (PIL Image or Tensor, int): Rescaled image and the selected anchor index,
            or just the anchor index when skip_resize is True.
        """
        if self.anchor_strategy == 'docowl':
            selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
        elif self.anchor_strategy == 'random':
            selected_anchor = random.randint(0, len(self.anchors) - 1)
        elif self.anchor_strategy == 'highest':
            # Prefer the anchor covering the largest area; break ties toward squarer layouts.
            selected_anchor = torch.argmax(self.anchors[:, 2] * self.anchors[:, 3] * 100 - torch.abs(self.anchors[:, 2] - self.anchors[:, 3]))
        elif self.anchor_strategy == 'last':
            selected_anchor = len(self.anchors) - 1
        elif self.anchor_strategy == 'llava':
            selected_anchor = select_best_resolution(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
        else:
            selected_anchor = None
        assert selected_anchor is not None

        target_size = self.anchors[selected_anchor][2:].tolist()  # (width, height)
        if skip_resize:
            # Only the anchor index is needed, e.g. for shape bookkeeping.
            return selected_anchor
        return F.resize(img, [target_size[1], target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor

    def __repr__(self) -> str:
        detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
        return f"{self.__class__.__name__}{detail}"

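# Usage sketch for AnchorResize (hypothetical sizes): pick the best anchor for a PIL image and
# resize to it, or produce a plain global resize.
#   _resizer = AnchorResize(image_size=(448, 448), anchors=grid_dict['grid_squ_6'])
#   _resized, _anchor_idx = _resizer(pil_image)        # resized to (rows*448) x (cols*448)
#   _global_view = _resizer.resize_global(pil_image)   # resized to 448 x 448
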
class CutMixin:
    def __init__(self, cut_cfg={"anchors": "grid_squ_6", "anchor_strategy": "docowl", "cut_prompt": "v2", "add_global": True, "cut_prob": 1.0}) -> None:
        if cut_cfg is None:
            self.cut_enable = False
            return
        else:
            self.cut_enable = True
        image_size = self.image_size
        anchors = cut_cfg.get('anchors', 'grid_33')
        anchor_strategy = cut_cfg.get('anchor_strategy', 'docowl')
        cut_prompt = cut_cfg.get('cut_prompt', 'v0')
        self.cut_prob = cut_cfg.get('cut_prob', 1.0)

        self.force_shape_cut = cut_cfg.get('force_shape_cut', False)
        # Anchors consulted only when force_shape_cut is enabled; default to a valid grid key.
        force_shape_cut_anchors = cut_cfg.get('force_shape_cut_anchors', 'grid_squ_9')

        self.add_global = cut_cfg.get('add_global', False)

        # image_size is stored as (height, width)
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

        if anchors in grid_dict:
            anchors = grid_dict[anchors]
        else:
            anchors = eval(anchors)
        self.anchors = [tuple(_) for _ in anchors]
        self.anchor_max = max([max(_) for _ in self.anchors])
        self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC, anchor_strategy=anchor_strategy)

        if force_shape_cut_anchors in grid_dict:
            force_shape_cut_anchors = grid_dict[force_shape_cut_anchors]
        else:
            force_shape_cut_anchors = eval(force_shape_cut_anchors)
        self.force_shape_cut_anchors = [tuple(_) for _ in force_shape_cut_anchors]
        self.force_shape_cut_anchors_max = max([max(_) for _ in self.force_shape_cut_anchors])

        self.old_resizer = transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC)

        # Drop the Resize step from the host processor's pipeline: AnchorResize already resized
        # the image, so only ToTensor and Normalize remain.
        self.image_transform = transforms.Compose(self.image_transform.transforms[1:])
        if self.add_global:
            self.cut_prompt_template = cut_prompt_template_dict[cut_prompt + '_global']
        else:
            self.cut_prompt_template = cut_prompt_template_dict[cut_prompt]

        self.media_tokens = ["<|image|>", "<|video|>"]

    def _process_image(self, images):
        new_images = []
        cut_shape = []
        for image in images:
            raw_image = image
            # Resize to the best-matching anchor, then cut into image_size-sized sub-images.
            image, selected_anchor = self.resizer(image)
            image_input = self.image_transform(image)  # (C, H, W) after ToTensor/Normalize
            cut_shape.append((image_input.shape[1] // self.image_size[0], image_input.shape[2] // self.image_size[1]))  # (num_rows, num_cols)
            image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])

            new_images.append(image_input)

            if self.add_global:
                new_images.append(self.image_transform(self.resizer.resize_global(raw_image)).unsqueeze(0))
                cut_shape.append((1, 1))

        new_images = torch.cat(new_images, dim=0)
        cut_shape_indices = build_cut_shape_indices(cut_shape)
        return new_images, cut_shape, cut_shape_indices

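# Shape sketch for CutMixin._process_image: with a 448-px base (an assumed value) and one image
# cut into a (2, 3) grid plus the global view, it returns pixel values of shape (7, 3, 448, 448),
# cut_shape == [(2, 3), (1, 1)], and cut_shape_indices of shape (7, 3).
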
class mPLUGOwl3BatchFeature(BatchFeature):
    r"""
    Extends BatchFeature to support entries with varying image sizes.
    """
    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        def converter(value):
            # `key` is read from the enclosing loop below so the error can name the entry.
            try:
                if not is_tensor(value):
                    tensor = as_tensor(value)
                    return tensor
                return value
            except:  # noqa E722
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        for key, value in self.items():
            self[key] = recursive_converter(converter, value)
        return self

    def to(self, *args, **kwargs) -> "mPLUGOwl3BatchFeature":
        requires_backends(self, ["torch"])
        import torch

        def cast_tensor(v):
            # Only floating-point tensors are cast; other tensors are moved, everything else is left untouched.
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                # cast and send to device
                return v.to(*args, **kwargs)
            elif isinstance(v, torch.Tensor) and device is not None:
                return v.to(device=device)
            else:
                return v

        new_data = {}
        device = kwargs.get("device")
        # Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should always be the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        for k, v in self.items():
            new_data[k] = recursive_converter(cast_tensor, v)
        self.data = new_data
        return self

class mPLUGOwl3ImageProcessor(BaseImageProcessor, CutMixin):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        **kwargs
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        CutMixin.__init__(self)

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        cut_enable=True,
        **kwargs
    ) -> mPLUGOwl3BatchFeature:
        """Resize (and optionally cut) the input image(s) and return pixel values plus cut metadata."""
        if isinstance(images, Image.Image):
            images_list = [images]
        else:
            images_list = images

        if self.cut_enable and cut_enable:
            image_data, cut_shape, cut_shape_indices = self._process_image(images_list)
        else:
            image_data = [self.image_transform(self.resizer.resize_global(image)) for image in images_list]
            image_data = torch.stack(image_data, dim=0)
            cut_shape = cut_shape_indices = None

        return mPLUGOwl3BatchFeature(data={'pixel_values': image_data, 'cut_shape': cut_shape, 'cut_shape_indices': cut_shape_indices})

    def to_dict(self):
        encoder_dict = super().to_dict()
        # Non-serializable runtime objects are dropped from the saved config.
        pop_keys = ['image_transform', 'resizer', 'old_resizer', 'cut_prompt_template']
        for pk in pop_keys:
            encoder_dict.pop(pk, None)
        return encoder_dict

AutoImageProcessor.register("mPLUGOwl3ImageProcessor", mPLUGOwl3ImageProcessor)
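
if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the real image size and cut configuration come from
    # the model's preprocessor_config.json; 448 below is an assumed value.
    processor = mPLUGOwl3ImageProcessor(image_size=448)
    dummy = Image.new("RGB", (1000, 700), color=(127, 127, 127))
    batch = processor.preprocess(dummy, cut_enable=True)
    print(batch["pixel_values"].shape)       # (num_subimages + 1 global, 3, 448, 448)
    print(batch["cut_shape"])                # e.g. [(rows, cols), (1, 1)]
    print(batch["cut_shape_indices"].shape)  # (num_subimages + 1, 3)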