import copy
import io
import json
import os
import random
import warnings
import logging
from typing import Any
from copy import deepcopy
from distinctipy import distinctipy

import numpy as np
from PIL import Image, ImageDraw
import cv2
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms.functional import InterpolationMode
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from transformers import AutoConfig, AutoTokenizer
from pycocotools import mask

from mmengine import print_log
from mmengine.config import Config, ConfigDict

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
from xtuner.utils import DEFAULT_IMAGE_TOKEN

from .process_functions import (dynamic_preprocess, preprocess, preprocess_internlm,
                                preprocess_mpt, preprocess_phi3, preprocess_phi3_debug,
                                vcr_decode_mask_fn, point_rendering, box_rendering,
                                image_blending, contour_rendering)
from .utils import (expand2square, expand2square_mask, DEFAULT_VISION_PROMPT_TOKEN,
                    VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN, RGB_NAME)

class InternVLDataset(Dataset):
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 model_path,
                 data_path=None,
                 image_folder=None,
                 dataset_map_fn=None,
                 annotation_load_fn=None,
                 dynamic_image_size=False,
                 pad_image_to_square=False,
                 repeat_time=1,
                 max_length=8192,
                 num_dynamic_patch=None,
                 lazy_load=True,
                 group_by_length=False,
                 tokenizer=None,
                 support_prompt_types=["rectangle"],
                 pseudo_two_images_mode=False,
                 ot_image_processor=None,
                 vfm_name="RADIO"):
        super().__init__()

        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        self.annotation_load_fn = annotation_load_fn
        self.lazy_load = lazy_load
        self.dynamic_image_size = dynamic_image_size
        self.pad_image_to_square = pad_image_to_square
        self.group_by_length = group_by_length
        self.support_prompt_types = support_prompt_types
        self.pseudo_two_images_mode = pseudo_two_images_mode
        self.ot_image_processor = ot_image_processor
        self.vfm_name = vfm_name
        if vfm_name in ['DINOv2', 'ConvNext']:
            # Resizing and padding for these backbones are done manually in the
            # get-item methods, so disable the processor's own resize/crop.
            self.ot_image_processor.do_center_crop = False
            self.ot_image_processor.do_resize = False

        self.cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        self.template = self.cfg.template
        if num_dynamic_patch is not None and len(num_dynamic_patch) == 2:
            self.min_dynamic_patch = num_dynamic_patch[0]
            self.max_dynamic_patch = num_dynamic_patch[1]
        else:
            self.min_dynamic_patch = self.cfg.min_dynamic_patch
            self.max_dynamic_patch = self.cfg.max_dynamic_patch
        self.downsample_ratio = self.cfg.downsample_ratio
        self.image_size = self.cfg.force_image_size
        self.use_thumbnail = self.cfg.use_thumbnail
        patch_size = self.cfg.vision_config.patch_size
        # Number of LLM visual tokens contributed by one image tile.
        self.patch_token = int(
            (self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True)
        self._add_special_tokens()
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        json_data, hf_json_data = self.annotation_load_fn(
            data_path, repeat_time, image_folder=image_folder)

        if json_data is not None:
            self.json_data = json_data
        hf_json_data = DatasetDict({'train': HFDataset.from_list(hf_json_data)})
        if self.lazy_load:
            self.text_data = build_origin_dataset(hf_json_data, 'train')
        else:
            raise NotImplementedError

        self.image_folder = image_folder
        self._max_refetch = 1000
        self.tcs_loader = None

    def _add_special_tokens(self):
        special_tokens = [VPT_CONTEXT_TOKEN]
        num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True)

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            if self.lazy_load:
                # Samples are not tokenized yet under lazy loading; use a constant estimate.
                cur_len = 100
            else:
                cur_len = len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                # Negative lengths mark text-only samples for the length-grouped sampler.
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list

    def _rand_another(self):
        return np.random.randint(0, len(self.text_data))

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, index) -> Any:
        # Retry with random indices when a sample fails to load or is filtered out.
        for _ in range(self._max_refetch + 1):
            data = self.prepare_data(index)
            if data is None:
                index = self._rand_another()
                continue
            return data

    def prepare_data(self, index):
        if hasattr(self, 'json_data'):
            data_dict = copy.deepcopy(self.json_data[index])
            data_dict.update(self.text_data[index])
        else:
            data_dict = copy.deepcopy(self.text_data[index])

        if self.lazy_load:
            result = self.dataset_map_fn(data_dict)
            if result is None:
                return None
            data_dict.update(result)

            if 'image' in data_dict and data_dict['image'] is not None and len(data_dict['image']) != 0:
                # A list of images (or pseudo two-image mode) goes through the multi-image path.
                if isinstance(data_dict['image'], list) or self.pseudo_two_images_mode:
                    ret = self.multi_modal_multi_image_get_item(data_dict)
                else:
                    ret = self.multi_modal_get_item(data_dict)
            elif 'video' in data_dict and data_dict['video'] is not None and data_dict['video'] != '':
                ret = self.video_get_item(data_dict)
            else:
                ret = self.pure_text_get_item(data_dict)
            return ret
        else:
            raise NotImplementedError

    def get_preprocess_function(self):
        # Select the conversation preprocess function that matches the chat template.
        if self.template == "Hermes-2":
            preprocess_function = preprocess_mpt
        elif self.template in ("internlm2-chat", "internvl2_5"):
            preprocess_function = preprocess_internlm
            self.template = "internlm2-chat"
        elif self.template == "phi3-chat":
            preprocess_function = preprocess_phi3_debug
        else:
            preprocess_function = preprocess
        return preprocess_function

    def load_image(self, image_path):
        if self.tcs_loader is not None and 's3://' in image_path:
            return self.tcs_loader(image_path)
        return Image.open(image_path).convert('RGB')

    def decode_mask(self, object_masks, ori_height, ori_width):
        binary_masks = []
        for object_mask in object_masks:
            if isinstance(object_mask, dict):
                if isinstance(object_mask["counts"], list):
                    # Uncompressed RLE: convert to compressed RLE first.
                    object_mask = mask.frPyObjects(object_mask, ori_height, ori_width)
                m = mask.decode(object_mask)
                m = m.astype(np.uint8).squeeze()
            elif object_mask:
                # Polygon(s): convert to RLE, merge the parts, then decode.
                rles = mask.frPyObjects(object_mask, ori_height, ori_width)
                rle = mask.merge(rles)
                m = mask.decode(rle).astype(np.uint8).squeeze()
            else:
                # Empty annotation: fall back to an all-zero mask.
                m = np.zeros((ori_height, ori_width), dtype=np.uint8)
            binary_masks.append(m)
        if len(binary_masks) == 0:
            binary_masks.append(np.zeros((ori_height, ori_width), dtype=np.uint8))
        masks = np.stack(binary_masks, axis=0)
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)

        return masks

    def multi_modal_get_item(self, data_item):
        # Make sure the first user turn contains the image placeholder token.
        if DEFAULT_IMAGE_TOKEN not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = (
                DEFAULT_IMAGE_TOKEN + '\n' + data_item['conversations'][0]['value'])

        image_path = os.path.join(self.image_folder, data_item['image'])

        try:
            image = self.load_image(image_path)
        except Exception as e:
            print(f'Error: {e}', flush=True)
            print_log(f'Error: {e}', logger='current')
            return None
        if image is None:
            return None
        ori_width, ori_height = image.size
        if ori_width < 10 or ori_height < 10:
            return None

        # Render the visual prompts (region contours) onto a BGR copy of the image.
        merged_visual_prompts = cv2.imread(image_path)
        if merged_visual_prompts is None:
            # Fall back to the PIL-loaded image when cv2.imread cannot read the file.
            merged_visual_prompts = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        regions = np.zeros(shape=(1, ori_height, ori_width), dtype=np.uint8)
        has_visual_prompts = False
        if 'annotation' in data_item and data_item['annotation'] is not None and len(data_item['annotation']) > 0:
            annotations = data_item['annotation']
            sampled_inds = data_item.get('sampled_inds', list(range(len(annotations))))
            if self.annotation_load_fn.__name__ == 'RegionShortConversationVCRDataset_load_fn':
                bboxes = [annotations[idx]['bbox'] for idx in sampled_inds]
                segms = [annotations[idx]['segmentation'] for idx in sampled_inds]
                regions = vcr_decode_mask_fn(bboxes, segms, ori_height, ori_width)
                regions = (regions > 0.0).astype(np.uint8)
            elif self.annotation_load_fn.__name__ == 'MDPVBoxOCRDataset_load_fn':
                bboxes = [annotations[idx]['bbox'] for idx in sampled_inds]
                regions = np.zeros(shape=(len(bboxes), ori_height, ori_width), dtype=np.uint8)
                for bidx, bbox in enumerate(bboxes):
                    # bbox is [x0, y0, x1, y1]; cast to int before slicing.
                    x0, y0, x1, y1 = map(int, bbox)
                    regions[bidx, y0:y1, x0:x1] = 1
            else:
                segms = [annotations[idx]['segmentation'] for idx in sampled_inds]
                regions = self.decode_mask(segms, ori_height=ori_height, ori_width=ori_width)
            try:
                contour_rendering(merged_visual_prompts, regions)
            except Exception as e:
                # Rendering failures are non-fatal; keep the un-annotated image.
                pass
            has_visual_prompts = True
        merged_visual_prompts = Image.fromarray(cv2.cvtColor(merged_visual_prompts, cv2.COLOR_BGR2RGB))

        if self.dynamic_image_size:
            try:
                images, _regions, merged_regions = dynamic_preprocess(
                    image, regions, merged_visual_prompts,
                    min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
                    image_size=self.image_size, use_thumbnail=self.use_thumbnail)
            except AssertionError as e:
                return None
        elif self.pad_image_to_square:
            image = expand2square(
                image,
                tuple(int(x * 255) for x in self.IMAGENET_MEAN))
            images = [image]
            merged_visual_prompts = expand2square(
                merged_visual_prompts,
                tuple(int(x * 255) for x in self.IMAGENET_MEAN))
            merged_regions = [merged_visual_prompts]
        else:
            images = [image]
            merged_regions = [merged_visual_prompts]

        pixel_values = [self.transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)

        merged_visual_prompts = [self.transform(merged_region) for merged_region in merged_regions]
        merged_visual_prompts = torch.stack(merged_visual_prompts)

        # Resize every region mask to the model input resolution.
        transformed_visual_prompts = []
        for region in _regions:
            transformed_regions = []
            for _region in region:
                resized_region = cv2.resize(
                    _region[:, :, np.newaxis], dsize=(self.image_size, self.image_size),
                    interpolation=cv2.INTER_NEAREST_EXACT)
                transformed_regions.append(torch.from_numpy(resized_region).squeeze(-1))
            transformed_visual_prompts.append(torch.stack(transformed_regions))
        visual_prompts = torch.stack(transformed_visual_prompts)

        assert merged_visual_prompts.shape[:2] == pixel_values.shape[:2]

        if self.vfm_name == "DINOv2":
            OT_FORCE_IMAGE_SIZE = 512
        elif self.vfm_name in ["RADIO", "ConvNext"]:
            OT_FORCE_IMAGE_SIZE = 1024
        else:
            raise NotImplementedError

        # Input for the vision foundation model: resize the longest side to
        # OT_FORCE_IMAGE_SIZE and pad the remainder with white.
        image = self.load_image(image_path)
        w, h = image.size
        if w > h:
            target_size = (OT_FORCE_IMAGE_SIZE, int(h / w * OT_FORCE_IMAGE_SIZE))
        else:
            target_size = (int(w / h * OT_FORCE_IMAGE_SIZE), OT_FORCE_IMAGE_SIZE)
        resized_image = image.resize(target_size)
        cur_w, cur_h = resized_image.size
        # White padding, consistent with the multi-image branch.
        padded_image = np.ones(shape=(OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE, 3), dtype=np.uint8) * 255
        padded_image[:cur_h, :cur_w, :] = np.array(resized_image)
        ot_pixel_values = self.ot_image_processor(images=padded_image, return_tensors='pt').pixel_values

        # Resize and pad the region masks the same way for the VFM branch.
        ot_visual_prompts = torch.tensor(regions).\
            to(ot_pixel_values.dtype).to(ot_pixel_values.device)
        h, w = ot_visual_prompts.shape[-2:]
        if h > w:
            target_size = (OT_FORCE_IMAGE_SIZE, int(w / h * OT_FORCE_IMAGE_SIZE))
        else:
            target_size = (int(h / w * OT_FORCE_IMAGE_SIZE), OT_FORCE_IMAGE_SIZE)
        resized_ot_visual_prompts = F.interpolate(
            ot_visual_prompts.unsqueeze(1), size=target_size, mode="bilinear").squeeze(1)
        resized_padded_ot_visual_prompts = resized_ot_visual_prompts.new_zeros(
            (resized_ot_visual_prompts.shape[0], OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE))
        resized_padded_ot_visual_prompts[:, :target_size[0], :target_size[1]] = resized_ot_visual_prompts

        num_patches = pixel_values.size(0)
        if not self.dynamic_image_size:
            assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'

        preprocess_function = self.get_preprocess_function()

        # Enumerate the regions so the prompt can reference them via VPT context tokens.
        if has_visual_prompts:
            region_ids = [[region_id + 1 for region_id in range(ot_visual_prompts.shape[0])], ]
            object_tokens_str = ""
            for fidx, object_ids_fidx in enumerate(region_ids):
                object_tokens_str = object_tokens_str + "Regions in the image: "
                for object_id in object_ids_fidx:
                    object_tokens_str = object_tokens_str + f"<region-{object_id}>{VPT_CONTEXT_TOKEN}, "
                object_tokens_str = object_tokens_str[:-1] + ".\n"
        else:
            object_tokens_str = ""

        ret = preprocess_function(self.template, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.patch_token * num_patches],
                                  group_by_length=self.group_by_length, ds_name="XXX",
                                  num_image=1, object_tokens_str=object_tokens_str)

        # Note: the contour-rendered image is returned as `pixel_values`, while the
        # clean image tiles are returned as `merged_visual_prompts`.
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=merged_visual_prompts,
            merged_visual_prompts=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
            num_patches=[num_patches, ],
            visual_prompts=visual_prompts.flatten(0, 1),
            num_vprompts=[visual_prompts.shape[0], ],
            vprompt_flags=[[1] * visual_prompts.shape[0], ] if has_visual_prompts else [[0] * visual_prompts.shape[0], ],
            num_images=1,
            ot_pixel_values=ot_pixel_values,
            ot_visual_prompts=resized_padded_ot_visual_prompts,
            region_ids=[[region_id + 1 for region_id in range(visual_prompts.shape[0])], ],
        )

        return ret

    def multi_modal_multi_image_get_item(self, data_item):
        image_name_list = data_item['image']

        image_path_list = [os.path.join(self.image_folder, image_name) for image_name in image_name_list]
        images = [self.load_image(image_path) for image_path in image_path_list]
        if any([item is None for item in images]):
            return None
        merged_visual_prompts = [cv2.imread(image_path) for image_path in image_path_list]
        for idx, item in enumerate(merged_visual_prompts):
            if item is not None:
                continue
            # Fall back to the PIL-loaded image when cv2.imread cannot read the file.
            merged_visual_prompts[idx] = cv2.cvtColor(np.asarray(images[idx]), cv2.COLOR_RGB2BGR)

        gt_region_id = -1
        visual_prompts_list, object_ids = [], []
        if 'pos_annotations' in data_item and data_item['pos_annotations'] is not None and len(data_item['pos_annotations']) > 0:
            pos_annotations = data_item['pos_annotations']
            neg_annotations = data_item['neg_annotations']

            # Pick one named colour and use it for every rendered contour.
            name_rgb = random.choice(RGB_NAME)
            color_name, color = [], []
            for k, v in name_rgb.items():
                color_name.append(str(k))
                color.append(v)
            color = color[0]
            color_name = color_name[0]
            color_anno_i = (color[2], color[1], color[0])  # RGB -> BGR for OpenCV
            # Draw the query object on every frame except the last one.
            for fidx in range(len(pos_annotations) - 1):
                ori_width, ori_height = images[fidx].size
                regions = self.decode_mask(pos_annotations[fidx], ori_height=ori_height, ori_width=ori_width)
                visual_prompts_list.append(regions)
                object_ids.append([])
                for region in regions:
                    contours, hierarchy = cv2.findContours(region, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(merged_visual_prompts[fidx], contours, -1, color=color_anno_i, thickness=2)
                    object_ids[fidx].append(1)
            # The last frame carries both positive and negative candidate objects.
            ori_width, ori_height = images[-1].size
            pos_neg_segms = pos_annotations[-1] + neg_annotations[-1]
            regions = self.decode_mask(pos_neg_segms, ori_height=ori_height, ori_width=ori_width)
            visual_prompts_list.append(regions)

            random_id = list(range(1, len(regions) + 1))
            random.shuffle(random_id)
            try:
                contour_rendering(merged_visual_prompts[-1], regions, random_id)
                object_ids.append([_id for _id in random_id])

                choice_names = [chr(i) for i in range(65, 91)]  # 'A' ... 'Z'
                if len(regions) > len(choice_names) - 1:
                    valid_num = len(choice_names) - 1
                else:
                    valid_num = len(regions)
                region_ids = random_id[:valid_num]
                choice_names = choice_names[:valid_num + 1]
                gt_region_id = region_ids[0] if not data_item['is_disappear'] else -1

                # Build the multiple-choice string; the last choice is always "none of the above".
                region_ids.sort()
                multi_choices_str = ""
                gt_choice_str = ""
                for choice_name, region_id in zip(choice_names[:-1], region_ids):
                    multi_choices_str = multi_choices_str + f"{choice_name}. {region_id}\n"
                    if region_id == gt_region_id:
                        assert gt_choice_str == ""
                        gt_choice_str = gt_choice_str + f"{choice_name}"

                multi_choices_str = multi_choices_str + f"{choice_names[-1]}. None of the above choices are correct\n"
                if gt_choice_str == "" or data_item['is_disappear'] or len(pos_annotations[-1]) == 0:
                    gt_choice_str = f"{choice_names[-1]}"

                conversations = data_item['conversations']
                for i, conversation in enumerate(conversations):
                    conversation_value = conversation['value']
                    conversation_value = conversation_value.format(
                        color=color_name, choices=multi_choices_str, answer=gt_choice_str)
                    conversation['value'] = conversation_value
                data_item['conversations'] = conversations
            except Exception as e:
                # Rendering or formatting failures are non-fatal for this sample.
                pass
        else:
            pass

        merged_visual_prompts = [Image.fromarray(cv2.cvtColor(item, cv2.COLOR_BGR2RGB)) for item in merged_visual_prompts]

        if self.dynamic_image_size:
            num_patches_list, images_list, merged_regions_list, crop_regions_list, num_vprompts_list = [], [], [], [], []
            for image, visual_prompts, merged_visual_prompt in zip(images, visual_prompts_list, merged_visual_prompts):
                try:
                    _images, regions, merged_regions = dynamic_preprocess(
                        image, visual_prompts, merged_visual_prompt,
                        min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
                        image_size=self.image_size, use_thumbnail=self.use_thumbnail)
                except AssertionError as e:
                    return None
                images_list.extend(_images)
                merged_regions_list.extend(merged_regions)
                crop_regions_list.extend(regions)
                num_patches_list.append(len(_images))
                num_vprompts_list.append(len(regions))
        else:
            raise NotImplementedError

        pixel_values = [self.transform(image) for image in images_list]
        pixel_values = torch.stack(pixel_values)

        merged_visual_prompts = [self.transform(merged_region) for merged_region in merged_regions_list]
        merged_visual_prompts = torch.stack(merged_visual_prompts)

        transformed_visual_prompts = []
        for region in crop_regions_list:
            transformed_regions = []
            for _region in region:
                resized_region = cv2.resize(
                    _region[:, :, np.newaxis], dsize=(self.image_size, self.image_size),
                    interpolation=cv2.INTER_NEAREST_EXACT)
                transformed_regions.append(torch.from_numpy(resized_region).squeeze(-1))
            transformed_visual_prompts.append(torch.stack(transformed_regions))
        visual_prompts = torch.stack(transformed_visual_prompts)

        assert merged_visual_prompts.shape[:2] == pixel_values.shape[:2]

        if self.vfm_name == "DINOv2":
            OT_FORCE_IMAGE_SIZE = 512
        elif self.vfm_name in ["RADIO", "ConvNext"]:
            OT_FORCE_IMAGE_SIZE = 1024
        else:
            raise NotImplementedError

        ot_pixel_values = []
        for fi, image in enumerate(images):
            w, h = image.size
            if w > h:
                target_size = (OT_FORCE_IMAGE_SIZE, int(h / w * OT_FORCE_IMAGE_SIZE))
            else:
                target_size = (int(w / h * OT_FORCE_IMAGE_SIZE), OT_FORCE_IMAGE_SIZE)
            resized_image = image.resize(target_size)
            cur_w, cur_h = resized_image.size
            padded_image = np.ones(shape=(OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE, 3), dtype=np.uint8) * 255
            padded_image[:cur_h, :cur_w, :] = np.array(resized_image)

            ot_pixel_values.append(self.ot_image_processor(images=Image.fromarray(padded_image), return_tensors='pt').pixel_values)

        ot_pixel_values = torch.cat(ot_pixel_values)

        ot_visual_prompts = torch.from_numpy(np.concatenate(visual_prompts_list, axis=0)).\
            to(ot_pixel_values.dtype).to(ot_pixel_values.device)
        h, w = ot_visual_prompts.shape[-2:]
        if h > w:
            target_size = (OT_FORCE_IMAGE_SIZE, int(w / h * OT_FORCE_IMAGE_SIZE))
        else:
            target_size = (int(h / w * OT_FORCE_IMAGE_SIZE), OT_FORCE_IMAGE_SIZE)
        resized_ot_visual_prompts = F.interpolate(ot_visual_prompts.unsqueeze(1), size=target_size, mode="bilinear").squeeze(1)
        resized_padded_ot_visual_prompts = resized_ot_visual_prompts.new_zeros(
            (resized_ot_visual_prompts.shape[0], OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE))
        resized_padded_ot_visual_prompts[:, :target_size[0], :target_size[1]] = resized_ot_visual_prompts

        num_patches = pixel_values.size(0)
        if not self.dynamic_image_size:
            assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'

        preprocess_function = self.get_preprocess_function()

        if gt_region_id != object_ids[-1][0] and gt_region_id != -1:
            print("query object id doesn't match with the candidate ids.")
            return None
        # One "<query object>" context token per query region in every frame but the last.
        region_ids = [[gt_region_id for _ in object_ids[fidx]]
                      for fidx in range(len(num_vprompts_list) - 1)]
        object_tokens_str = ""
        for fidx, object_ids_fidx in enumerate(region_ids):
            object_tokens_str = object_tokens_str + f"Objects in Image-{fidx + 1}: "
            for object_id in range(1, len(object_ids_fidx) + 1):
                object_tokens_str = object_tokens_str + f"<query object>{VPT_CONTEXT_TOKEN}, "
            object_tokens_str = object_tokens_str[:-2] + ".\n"
        # Candidate objects in the last image, listed in ascending id order.
        sorted_indices = sorted(range(len(object_ids[-1])), key=lambda k: object_ids[-1][k])
        sorted_cand_object_ids = []
        object_tokens_str = object_tokens_str + f"Objects in Image-{len(object_ids)}: "
        for sorted_idx in sorted_indices:
            object_id = object_ids[-1][sorted_idx]
            object_tokens_str = object_tokens_str + f"<object-{object_id}>{VPT_CONTEXT_TOKEN}, "
            sorted_cand_object_ids.append(object_id)
        object_tokens_str = object_tokens_str[:-2] + ".\n"
        region_ids = region_ids + [sorted_cand_object_ids, ]

        # Reorder the candidate visual prompts of the last image to match the sorted ids.
        total_vprompts = resized_padded_ot_visual_prompts.shape[0]
        cand_visual_prompts = resized_padded_ot_visual_prompts[total_vprompts - num_vprompts_list[-1]:]
        sorted_cand_visual_prompts = []
        for sorted_idx in sorted_indices:
            sorted_cand_visual_prompts.append(cand_visual_prompts[sorted_idx])
        sorted_cand_visual_prompts = torch.stack(sorted_cand_visual_prompts)
        resized_padded_ot_visual_prompts = torch.cat(
            [resized_padded_ot_visual_prompts[:total_vprompts - num_vprompts_list[-1]], sorted_cand_visual_prompts])

        ret = preprocess_function(self.template, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.patch_token * num_patch for num_patch in num_patches_list],
                                  group_by_length=self.group_by_length, ds_name="XXX",
                                  num_image=len(num_patches_list), object_tokens_str=object_tokens_str)

        # The ROI variant swaps the roles of the clean and contour-rendered tensors.
        roi_version = self.dataset_map_fn.__name__ == "match_reasoning_map_fn_roi"

        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values if not roi_version else merged_visual_prompts,
            merged_visual_prompts=merged_visual_prompts if not roi_version else pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
            num_patches=num_patches_list,
            visual_prompts=visual_prompts.flatten(0, 1),
            num_vprompts=num_vprompts_list,
            vprompt_flags=[[1 for _ in range(nvp)] for nvp in num_vprompts_list],
            num_images=len(num_vprompts_list),
            ot_pixel_values=ot_pixel_values,
            ot_visual_prompts=resized_padded_ot_visual_prompts,
            region_ids=region_ids,
        )

        return ret

    def video_get_item(self, data_item):
        raise NotImplementedError

    def pure_text_get_item(self, data_item):
        # Text-only samples reuse the multimodal pipeline with a dummy white image
        # and an all-zero visual prompt; image_flags/vprompt_flags mark them as unused.
        ori_height = ori_width = 448
        image = Image.new('RGB', (ori_height, ori_width), (255, 255, 255))
        merged_visual_prompts = np.zeros((ori_height, ori_width, 3), dtype=np.uint8)
        merged_visual_prompts = Image.fromarray(merged_visual_prompts)

        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.IMAGENET_MEAN))
        images = [image]
        merged_visual_prompts = expand2square(merged_visual_prompts, (0, 0, 0))
        merged_regions = [merged_visual_prompts]

        pixel_values = [self.transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)

        merged_visual_prompts = [self.transform(merged_region) for merged_region in merged_regions]
        merged_visual_prompts = torch.stack(merged_visual_prompts)

        visual_prompts = torch.zeros(
            size=(merged_visual_prompts.shape[0], merged_visual_prompts.shape[-2], merged_visual_prompts.shape[-1]),
            dtype=torch.long).to(merged_visual_prompts.device)

        if self.vfm_name == "DINOv2":
            OT_FORCE_IMAGE_SIZE = 512
        elif self.vfm_name in ["RADIO", "ConvNext"]:
            OT_FORCE_IMAGE_SIZE = 1024
        else:
            raise NotImplementedError

        image = Image.new('RGB', (OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE), (255, 255, 255))
        ot_pixel_values = self.ot_image_processor(images=image, return_tensors='pt').pixel_values
        ot_visual_prompts = torch.zeros((1, OT_FORCE_IMAGE_SIZE, OT_FORCE_IMAGE_SIZE)).\
            to(ot_pixel_values.dtype).to(ot_pixel_values.device)

        num_patches = pixel_values.size(0)
        if not self.dynamic_image_size:
            assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'

        preprocess_function = self.get_preprocess_function()

        ret = preprocess_function(self.template, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.patch_token * num_patches], text_only=True,
                                  group_by_length=self.group_by_length, ds_name="XXX",
                                  num_image=0)

        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            merged_visual_prompts=merged_visual_prompts,
            image_flags=torch.tensor([0] * num_patches, dtype=torch.long),
            num_patches=[num_patches, ],
            visual_prompts=visual_prompts,
            num_vprompts=[1, ],
            vprompt_flags=[[0, ], ],
            num_images=1,
            ot_pixel_values=ot_pixel_values,
            ot_visual_prompts=ot_visual_prompts,
            region_ids=[[1, ], ],
        )

        return ret
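

# Illustrative self-check, not part of the training pipeline: it demonstrates, on
# synthetic values, the COCO polygon -> RLE -> binary-mask conversion that
# `decode_mask` relies on. The polygon coordinates and the 100x100 canvas below
# are made-up placeholders.
if __name__ == '__main__':
    poly = [[10.0, 10.0, 60.0, 10.0, 60.0, 40.0, 10.0, 40.0]]  # one rectangular polygon
    rles = mask.frPyObjects(poly, 100, 100)  # polygons -> run-length encodings
    rle = mask.merge(rles)                   # merge multi-part polygons into a single RLE
    m = mask.decode(rle).astype(np.uint8)    # RLE -> binary mask on a 100x100 canvas
    print('decoded mask shape:', m.shape, 'foreground pixels:', int(m.sum()))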