| | import os
|
| | import random
|
| | import re
|
| | import copy
|
| | from dataclasses import dataclass, field
|
| | import json
|
| | import logging
|
| | import pathlib
|
| | from typing import Dict, Optional, Sequence, List
|
| | import bisect
|
| | import torch
|
| | import numpy as np
|
| | import transformers
|
| | from objectrelator.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
|
| | DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX, DEFAULT_CLS_TOKEN, CLS_TOKEN_INDEX, DEFAULT_REGION_TOKEN, \
|
| | REGION_TOKEN_INDEX, REFER_TOKEN_INDEX
|
| | from torch.utils.data import Dataset
|
| | from objectrelator import conversation as conversation_lib
|
| | from objectrelator.model import *
|
| | from objectrelator.mm_utils import tokenizer_image_token
|
| | from objectrelator.mask_config.data_args import TrainingArguments
|
| | from PIL import Image
|
| | from objectrelator.mask_config.config import Config
|
| | from fvcore.common.config import CfgNode
|
| | from detectron2.structures import BoxMode
|
| | import warnings
|
| |
|
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

local_rank = None

# Argument parser bound to the project's TrainingArguments dataclass.
parser = transformers.HfArgumentParser(TrainingArguments)
# Default-constructed arguments; ``training_args.first_stage`` is read by
# COCO_interactive_dataset_train in this module.
# NOTE(review): this instance is not populated from ``parser`` here — confirm
# that relying on the defaults is intended.
training_args = TrainingArguments()
|
| |
|
def get_mask_config(config='./objectrelator/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml'):
    """Build the mask-model configuration from a YAML file.

    The file is read twice: once through ``Config.fromfile`` and once through
    ``CfgNode.load_yaml_with_base``. The attributes of the former are merged
    into the latter, and the merged mapping is wrapped in a ``Config``.

    Args:
        config: Path to the YAML configuration file.

    Returns:
        A ``Config`` built from the merged settings.
    """
    overrides = Config.fromfile(config)
    merged = CfgNode.load_yaml_with_base(config, allow_unsafe=True)
    # Entries taken from the Config instance win over the YAML-with-base values.
    merged.update(overrides.__dict__.items())
    return Config(merged)
|
class COCO_panoptic_dataset(Dataset):
    """COCO panoptic annotations paired with a fixed segmentation prompt.

    Each item is produced by ``data_args.image_processor`` and then extended
    with the tokenized conversation (``input_ids`` / ``labels``) and the
    tokenized class names.
    """

    def __init__(self, json_path, tokenizer, data_args, is_train=True):
        """Load the panoptic annotation file and build the category tables.

        Args:
            json_path: Dataset root directory; the split sub-directories and
                ``annotations/panoptic_<split>.json`` are resolved under it.
            tokenizer: Tokenizer exposing ``encode``, ``pad_token_id`` and
                ``model_max_length``.
            data_args: Object exposing ``image_processor`` (a processor or a
                dict holding one under the ``'panoptic'`` key).
            is_train: Selects the ``train2017`` split when true, otherwise
                ``val2017``.
        """
        # The original ``super(COCO_panoptic_dataset).__init__()`` created an
        # unbound super object and never reached ``Dataset.__init__``.
        super().__init__()
        split = 'train2017' if is_train else 'val2017'
        self.panoptic_gt_path = os.path.join(json_path, f'panoptic_{split}')
        self.panoptic_image_path = os.path.join(json_path, split)
        self.panoptic_json_path = os.path.join(json_path, f'annotations/panoptic_{split}.json')
        self.semantic_gt_path = os.path.join(json_path, f'panoptic_semseg_{split}')

        with open(self.panoptic_json_path) as f:
            data = json.load(f)

        self.data = data['annotations']
        self.tokenizer = tokenizer
        self.data_args = data_args
        self.mask_format = 'polygon'

        categories = data['categories']
        # Map the dataset's category ids onto contiguous indices 0..N-1.
        self.coco_id_to_cont_id = {cat['id']: cont_id for cont_id, cat in enumerate(categories)}
        self.coco_class_name = [cat['name'] for cat in categories] + ['background']
        self.coco_is_thing = [cat['isthing'] for cat in categories]

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def preprocess_multimodal(self, sources):
        """Normalize image/seg token placement in every sentence, in place.

        The image token is moved to the front (followed by a newline) and the
        seg token to the end (preceded by a newline).

        Args:
            sources: List of conversations, each a list of dicts with a
                ``'value'`` string.

        Returns:
            The same ``sources`` object after modification.
        """
        for source in sources:
            for sentence in source:
                if DEFAULT_IMAGE_TOKEN in sentence['value']:
                    stripped = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                    sentence['value'] = (DEFAULT_IMAGE_TOKEN + '\n' + stripped).strip()
                    if "mmtag" in conversation_lib.default_conversation.version:
                        sentence['value'] = sentence['value'].replace(
                            DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')

                if DEFAULT_SEG_TOKEN in sentence['value']:
                    stripped = sentence['value'].replace(DEFAULT_SEG_TOKEN, '').strip()
                    sentence['value'] = stripped + '\n' + DEFAULT_SEG_TOKEN
        return sources

    def preprocess_llama2(self, sources, tokenizer):
        """Render conversations with the default template and build labels.

        Args:
            sources: List of conversations; each is a list of dicts with
                ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
            tokenizer: Tokenizer used for encoding and for ``pad_token_id`` /
                ``model_max_length``.

        Returns:
            dict with ``input_ids`` (stacked prompt tensors) and ``labels``
            (a clone in which masked positions are set to ``IGNORE_INDEX``).
        """
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply the prompt template to every conversation.
        conversations = []
        for i, source in enumerate(sources):
            if roles[source[0]["from"]] != conv.roles[0]:
                # Drop a leading turn that does not come from the first role.
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # torch.stack requires all prompts to tokenize to the same length.
        input_ids = torch.stack(
            [self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
            dim=0)
        targets = input_ids.clone()

        assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

        # Mask the leading token and the instruction part of every round.
        sep = "[/INST] "
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(tokenizer.pad_token_id).sum())

            cur_len = 1
            target[:cur_len] = IGNORE_INDEX
            for rou in conversation.split(conv.sep2):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep

                round_len = len(self.tokenizer_special_tokens(rou, tokenizer))
                # NOTE(review): the "- 2" offset is inherited as-is; presumably
                # it compensates for tokenizer boundary tokens — confirm.
                instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 2

                target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
                cur_len += round_len
            target[cur_len:] = IGNORE_INDEX

            # A length mismatch means the masking offsets are unreliable, so
            # the whole sample is excluded from the loss.
            if cur_len < tokenizer.model_max_length and cur_len != total_len:
                target[:] = IGNORE_INDEX

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
                                 seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
                                 region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
        """Tokenize ``prompt``, mapping placeholder tags to fixed indices.

        ``<image>``, ``<seg>``, ``<cls>`` and ``<region>`` are replaced by the
        corresponding index argument; all other text goes through
        ``tokenizer.encode(..., add_special_tokens=False)``.

        Args:
            prompt: Text to tokenize.
            tokenizer: Object exposing ``encode``.
            image_token_index: Id emitted for ``<image>``.
            seg_token_index: Id emitted for ``<seg>``.
            cls_token_index: Id emitted for ``<cls>``.
            region_token_index: Id emitted for ``<region>``.
            return_tensors: ``None`` for a Python list, ``'pt'`` for a
                squeezed ``torch.long`` tensor.

        Returns:
            List of ids, or a tensor when ``return_tensors == 'pt'``.

        Raises:
            ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
        """
        special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index,
                             '<cls>': cls_token_index, '<region>': region_token_index}
        input_ids = []
        # The capturing group keeps the placeholders as separate chunks.
        for chunk in re.split('(<image>|<seg>|<cls>|<region>)', prompt):
            if chunk in special_token_map:
                input_ids.append(special_token_map[chunk])
            else:
                input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))

        if return_tensors is None:
            return input_ids
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long).squeeze()
        raise ValueError(f'Unsupported tensor type: {return_tensors}')

    def preprocess_class_name(self, CLS_token='[CAT]'):
        """Tokenize every class name, each followed by one marker token.

        Args:
            CLS_token: Marker text; only the first id of its encoding is
                appended after each class name.

        Returns:
            Tuple ``(class_name_id, cls_indices)``: the flattened token ids
            and, for each token, the index of the class it belongs to.
        """
        # Encode the marker once instead of once per class.
        marker_id = self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]
        tokenized_class_names = [
            self.tokenizer.encode(class_name, add_special_tokens=False) + [marker_id]
            for class_name in self.coco_class_name
        ]
        class_name_id = torch.tensor([token for tokens in tokenized_class_names for token in tokens])
        cls_indices = torch.tensor(
            [idx for idx, tokens in enumerate(tokenized_class_names) for _ in tokens])
        return class_name_id, cls_indices

    def __getitem__(self, idx):
        """Build the processed sample for annotation record ``idx``.

        Returns:
            The dict returned by the image processor, extended with
            ``input_ids``, ``labels``, ``class_name_ids``, ``cls_indices`` and
            ``class_name_embedding_indices``.
        """
        data = self.data[idx]
        stem = os.path.splitext(data["file_name"])[0]
        data_dict = {
            'file_name': os.path.join(self.panoptic_image_path, stem + ".jpg"),
            'image_id': int(data["image_id"]),
            'pan_seg_file_name': os.path.join(self.panoptic_gt_path, data["file_name"]),
            'sem_seg_file_name': os.path.join(self.semantic_gt_path, data["file_name"]),
            # Remap on copies: rewriting the stored records in place would
            # remap already-remapped ids the next time this index is read.
            'segments_info': [
                dict(seg, category_id=self.coco_id_to_cont_id[seg['category_id']])
                for seg in data["segments_info"]
            ],
        }

        image_processor = self.data_args.image_processor
        processor = image_processor['panoptic'] if isinstance(image_processor, dict) else image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)

        prefix_inst = 'This is an image <image>, Please do Panoptic Segmentation.'
        # One <cls> placeholder per class name, background included.
        num_class = len(self.coco_class_name)
        category = '<cls>, ' * (num_class - 1) + '<cls>.'
        sources_value = f'\nThis is all the candidate categories: {category}\n'
        sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
                    {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]

        class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
        # 1 at every <cls> placeholder position, 0 elsewhere.
        class_name_embedding_indices = torch.zeros_like(input_ids)
        class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['class_name_ids'] = class_name_ids
        data_dict['cls_indices'] = cls_indices
        data_dict['class_name_embedding_indices'] = class_name_embedding_indices
        return data_dict
|
| |
|
class COCO_interactive_dataset_train(COCO_panoptic_dataset):
    """Region-prompted segmentation samples for training.

    Records are read from one or several JSON files. Tokenization helpers
    (``tokenizer_special_tokens``, ``preprocess_class_name``,
    ``preprocess_llama2``) are inherited from ``COCO_panoptic_dataset``; the
    local copies that used to live here were identical to the inherited ones.
    """

    def __init__(self, json_path, tokenizer, data_args):
        """Load the records and set up the COCO category tables.

        Args:
            json_path: Path to a JSON file, or a list of paths whose contents
                are concatenated.
            tokenizer: Tokenizer used to encode the prompt.
            data_args: Object exposing ``image_folder``, ``image_processor``
                and optionally ``region_mask_type``.
        """
        if isinstance(json_path, list):
            data = []
            for path in json_path:
                with open(path) as f:
                    data.extend(json.load(f))
        else:
            with open(json_path) as f:
                data = json.load(f)
        self.data = data

        # First-stage training keeps a random 1/20 subset of the records.
        # NOTE(review): ``training_args`` is the module-level default-built
        # instance — confirm ``first_stage`` is meant to come from there.
        if training_args.first_stage:
            subset_size = len(self.data) // 20
            self.data = random.sample(self.data, subset_size)
            print('!!!!!!!!!!!!!!!!!!!!!!! Len of Stage1 Training;!!!!!!!!!!!!!!!!!!', len(self.data))

        self.tokenizer = tokenizer
        self.data_args = data_args
        coco_class_ids = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
            50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
            64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
            82, 84, 85, 86, 87, 88, 89, 90
        ]
        coco_class_name = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
            'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
            'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle',
            'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'potted plant', 'bed', 'dining table', 'toilet',
            'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]

        self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
        self.coco_class_name = coco_class_name + ['background']

    def __getitem__(self, idx):
        """Build the processed region-segmentation sample for record ``idx``.

        Returns:
            The dict returned by the image processor, extended with
            ``input_ids``, ``labels`` and ``dataset_type='region_coco'``.

        Raises:
            ValueError: If an annotation's ``category_id`` is neither a known
                COCO id nor a contiguous id.
        """
        data = self.data[idx]
        image_id = data['new_img_id']

        # Work on per-annotation copies. Rewriting the stored records would
        # make a later access to the same index remap an already-contiguous
        # id a second time (e.g. contiguous 1 read as COCO id 1 -> 0).
        annotations = []
        for stored in data['anns']:
            annotation = dict(stored)
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            category_id = annotation['category_id']
            if category_id in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(f'Unknown category_id {category_id!r} for image {image_id!r}')
            annotation['image_id'] = image_id
            annotations.append(annotation)

        data_dict = {
            'file_name': os.path.join(self.data_args.image_folder, data['image']),
            'height': data['image_info']['height'],
            'width': data['image_info']['width'],
            'image_id': image_id,
            'annotations': annotations,
        }

        image_processor = self.data_args.image_processor
        processor = image_processor['instance'] if isinstance(image_processor, dict) else image_processor
        # ``region_mask_type`` is an optional '||'-separated string.
        region_mask_type = getattr(self.data_args, 'region_mask_type', None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')
        data_dict = processor.preprocess(data_dict, region_mask_type=region_mask_type)

        # One <region> placeholder per instance returned by the processor.
        num_target = len(data_dict['instances'])
        prefix_inst = 'This is an image <image>, Please segment by given regions'
        regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
        sources_value = f'\nThis is all regions: {regions_inst}\n'
        sources = [
            [{'from': 'human', 'value': prefix_inst + sources_value},
             {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        data_dict['input_ids'] = text_dict['input_ids'][0]
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'region_coco'
        return data_dict
|
| |
|
class COCO_interactive_dataset_eval(COCO_panoptic_dataset):
    """Region-prompted segmentation samples built from an in-memory list.

    Tokenization helpers (``tokenizer_special_tokens``,
    ``preprocess_class_name``, ``preprocess_llama2``) are inherited from
    ``COCO_panoptic_dataset``; the local copies that used to live here were
    identical to the inherited ones.
    """

    def __init__(self, data_list, tokenizer, data_args):
        """Store the records and set up the COCO category tables.

        Args:
            data_list: Already-loaded records, indexed by ``__getitem__``.
            tokenizer: Tokenizer used to encode the prompt.
            data_args: Object exposing ``image_folder``, ``image_processor``
                and optionally ``region_mask_type``.
        """
        self.data = data_list
        self.tokenizer = tokenizer
        self.data_args = data_args
        coco_class_ids = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
            50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
            64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
            82, 84, 85, 86, 87, 88, 89, 90
        ]
        coco_class_name = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
            'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
            'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle',
            'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'potted plant', 'bed', 'dining table', 'toilet',
            'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]

        self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
        self.coco_class_name = coco_class_name + ['background']

    def __getitem__(self, idx):
        """Build the processed region-segmentation sample for record ``idx``.

        Returns:
            The dict returned by the image processor, extended with
            ``input_ids``, ``labels`` and ``dataset_type='region_coco'``.

        Raises:
            ValueError: If an annotation's ``category_id`` is neither a known
                COCO id nor a contiguous id.
        """
        data = self.data[idx]
        image_id = data['new_img_id']

        # Work on per-annotation copies. Rewriting the stored records would
        # make a later access to the same index remap an already-contiguous
        # id a second time (e.g. contiguous 1 read as COCO id 1 -> 0).
        annotations = []
        for stored in data['anns']:
            annotation = dict(stored)
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            category_id = annotation['category_id']
            if category_id in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(f'Unknown category_id {category_id!r} for image {image_id!r}')
            annotation['image_id'] = image_id
            annotations.append(annotation)

        data_dict = {
            'file_name': os.path.join(self.data_args.image_folder, data['image']),
            'height': data['image_info']['height'],
            'width': data['image_info']['width'],
            'image_id': image_id,
            'annotations': annotations,
        }

        image_processor = self.data_args.image_processor
        processor = image_processor['instance'] if isinstance(image_processor, dict) else image_processor
        # ``region_mask_type`` is an optional '||'-separated string.
        region_mask_type = getattr(self.data_args, 'region_mask_type', None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')
        data_dict = processor.preprocess(data_dict, region_mask_type=region_mask_type)

        # One <region> placeholder per instance returned by the processor.
        num_target = len(data_dict['instances'])
        prefix_inst = 'This is an image <image>, Please segment by given regions'
        regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
        sources_value = f'\nThis is all regions: {regions_inst}\n'
        sources = [
            [{'from': 'human', 'value': prefix_inst + sources_value},
             {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        data_dict['input_ids'] = text_dict['input_ids'][0]
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'region_coco'
        return data_dict
|
| |
|
class COCO_interactive_dataset(COCO_panoptic_dataset):
    """Region-prompted segmentation samples read from JSON files.

    Tokenization helpers (``tokenizer_special_tokens``,
    ``preprocess_class_name``, ``preprocess_llama2``) are inherited from
    ``COCO_panoptic_dataset``; the local copies that used to live here were
    identical to the inherited ones.
    """

    def __init__(self, json_path, tokenizer, data_args):
        """Load the records and set up the COCO category tables.

        Args:
            json_path: Path to a JSON file, or a list of paths whose contents
                are concatenated.
            tokenizer: Tokenizer used to encode the prompt.
            data_args: Object exposing ``image_folder``, ``image_processor``
                and optionally ``region_mask_type``.
        """
        if isinstance(json_path, list):
            data = []
            for path in json_path:
                with open(path) as f:
                    data.extend(json.load(f))
        else:
            with open(json_path) as f:
                data = json.load(f)
        self.data = data

        self.tokenizer = tokenizer
        self.data_args = data_args
        coco_class_ids = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
            50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
            64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
            82, 84, 85, 86, 87, 88, 89, 90
        ]
        coco_class_name = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
            'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
            'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle',
            'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'potted plant', 'bed', 'dining table', 'toilet',
            'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]

        self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
        self.coco_class_name = coco_class_name + ['background']

    def __getitem__(self, idx):
        """Build the processed region-segmentation sample for record ``idx``.

        Returns:
            The dict returned by the image processor, extended with
            ``input_ids``, ``labels`` and ``dataset_type='region_coco'``.

        Raises:
            ValueError: If an annotation's ``category_id`` is neither a known
                COCO id nor a contiguous id.
        """
        data = self.data[idx]
        image_id = data['new_img_id']

        # Work on per-annotation copies. Rewriting the stored records would
        # make a later access to the same index remap an already-contiguous
        # id a second time (e.g. contiguous 1 read as COCO id 1 -> 0).
        annotations = []
        for stored in data['anns']:
            annotation = dict(stored)
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            category_id = annotation['category_id']
            if category_id in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(f'Unknown category_id {category_id!r} for image {image_id!r}')
            annotation['image_id'] = image_id
            annotations.append(annotation)

        data_dict = {
            'file_name': os.path.join(self.data_args.image_folder, data['image']),
            'height': data['image_info']['height'],
            'width': data['image_info']['width'],
            'image_id': image_id,
            'annotations': annotations,
        }

        image_processor = self.data_args.image_processor
        processor = image_processor['instance'] if isinstance(image_processor, dict) else image_processor
        # ``region_mask_type`` is an optional '||'-separated string.
        region_mask_type = getattr(self.data_args, 'region_mask_type', None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')
        data_dict = processor.preprocess(data_dict, region_mask_type=region_mask_type)

        # One <region> placeholder per instance returned by the processor.
        num_target = len(data_dict['instances'])
        prefix_inst = 'This is an image <image>, Please segment by given regions'
        regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
        sources_value = f'\nThis is all regions: {regions_inst}\n'
        sources = [
            [{'from': 'human', 'value': prefix_inst + sources_value},
             {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        data_dict['input_ids'] = text_dict['input_ids'][0]
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'region_coco'
        return data_dict
|
| |
|
| |
|
class COCO_instance_dataset(COCO_interactive_dataset_train):
    """COCO instance annotations paired with the class-list segmentation prompt.

    Tokenization helpers (``tokenizer_special_tokens``,
    ``preprocess_class_name``, ``preprocess_llama2``) are inherited; the local
    copies that used to live here were identical to the inherited ones.
    """

    def __init__(self, json_path, tokenizer, data_args):
        """Load the records and set up the COCO category tables.

        Unlike the parent class, no first-stage subsampling is applied.

        Args:
            json_path: Path to a JSON file, or a list of paths whose contents
                are concatenated.
            tokenizer: Tokenizer used to encode the prompt and class names.
            data_args: Object exposing ``image_folder`` and
                ``image_processor``.
        """
        if isinstance(json_path, list):
            data = []
            for path in json_path:
                with open(path) as f:
                    data.extend(json.load(f))
        else:
            with open(json_path) as f:
                data = json.load(f)
        self.data = data
        self.tokenizer = tokenizer
        self.data_args = data_args
        self.mask_format = 'polygon'
        coco_class_ids = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
            18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
            35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
            50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
            64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
            82, 84, 85, 86, 87, 88, 89, 90
        ]
        coco_class_name = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
            'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
            'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
            'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle',
            'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli',
            'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'potted plant', 'bed', 'dining table', 'toilet',
            'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]
        self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
        self.coco_class_name = coco_class_name + ['background']

    def __getitem__(self, idx):
        """Build the processed instance-segmentation sample for record ``idx``.

        Returns:
            The dict returned by the image processor, with ``annotations``
            re-attached and extended with ``input_ids``, ``labels``,
            ``class_name_ids``, ``cls_indices`` and
            ``class_name_embedding_indices``.

        Raises:
            ValueError: If an annotation's ``category_id`` is neither a known
                COCO id nor a contiguous id.
        """
        data = self.data[idx]
        image_id = data['new_img_id']

        # Work on per-annotation copies. Rewriting the stored records would
        # make a later access to the same index remap an already-contiguous
        # id a second time (e.g. contiguous 1 read as COCO id 1 -> 0).
        annotations = []
        for stored in data['anns']:
            annotation = dict(stored)
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            category_id = annotation['category_id']
            if category_id in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(f'Unknown category_id {category_id!r} for image {image_id!r}')
            annotation['image_id'] = image_id
            annotations.append(annotation)

        data_dict = {
            'file_name': os.path.join(self.data_args.image_folder, data['image']),
            'height': data['image_info']['height'],
            'width': data['image_info']['width'],
            'image_id': image_id,
            'annotations': annotations,
        }

        image_processor = self.data_args.image_processor
        processor = image_processor['instance'] if isinstance(image_processor, dict) else image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
        # Re-attach the remapped annotations to the processor's output.
        data_dict['annotations'] = annotations

        prefix_inst = 'This is an image <image>, Please do Panoptic Segmentation.'
        # One <cls> placeholder per class name, background included.
        num_class = len(self.coco_class_name)
        category = '<cls>, ' * (num_class - 1) + '<cls>.'
        sources_value = f'\nThis is all the candidate categories: {category}\n'
        sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
                    {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]

        class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
        # 1 at every <cls> placeholder position, 0 elsewhere.
        class_name_embedding_indices = torch.zeros_like(input_ids)
        class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['class_name_ids'] = class_name_ids
        data_dict['cls_indices'] = cls_indices
        data_dict['class_name_embedding_indices'] = class_name_embedding_indices
        return data_dict
|
| |
|
| |
|
| |
|
class COCO_panoptic_dataset_random(COCO_panoptic_dataset):
    """Panoptic dataset variant that shuffles the class-name order per sample."""

    def preprocess_class_name(self, CLS_token='[CAT]'):
        """Tokenise the class names in a freshly shuffled order.

        Args:
            CLS_token: Text whose first token id is appended after every
                tokenised class name.

        Returns:
            A tuple ``(class_name_id, cls_indices, permute_idx)`` of tensors:
            the flattened token ids of all (shuffled) class names, the
            index of the shuffled class each token belongs to, and, for each
            original class index, its position in the shuffled order.
        """
        random_idx = list(range(len(self.coco_class_name)))
        random.shuffle(random_idx)
        random_class_name = [self.coco_class_name[i] for i in random_idx]
        # Argsort of the shuffle, i.e. the inverse permutation.
        permute_idx = sorted(range(len(random_idx)), key=random_idx.__getitem__)

        # The marker token id does not depend on the class, so encode it once.
        cls_token_id = self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]
        tokenized_class_names = [
            self.tokenizer.encode(class_name, add_special_tokens=False) + [cls_token_id]
            for class_name in random_class_name
        ]

        class_name_id = torch.tensor(
            [token for sublist in tokenized_class_names for token in sublist])
        cls_indices = torch.tensor(
            [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist])
        permute_idx = torch.tensor(permute_idx)

        return class_name_id, cls_indices, permute_idx

    def __getitem__(self, idx):
        """Assemble one panoptic training sample for the record at ``idx``.

        Args:
            idx: Index into ``self.data``.

        Returns:
            The dict returned by ``processor.preprocess`` extended with
            ``random_idx``, ``input_ids``, ``labels``, ``dataset_type``
            (``'panoptic_coco'``), ``class_name_ids``, ``cls_indices`` and
            ``class_name_embedding_indices``.

        Raises:
            ValueError: If a segment's ``category_id`` is neither a key nor
                a value of ``self.coco_id_to_cont_id``.
        """
        data = self.data[idx]
        image_id = int(data["image_id"])
        image_file = os.path.join(
            self.panoptic_image_path, os.path.splitext(data["file_name"])[0] + ".jpg")

        data_dict = {
            'file_name': image_file,
            'image_id': image_id,
            'pan_seg_file_name': os.path.join(self.panoptic_gt_path, data["file_name"]),
            'sem_seg_file_name': os.path.join(self.semantic_gt_path, data["file_name"]),
        }

        # Remap on per-call copies of the segment dicts. Writing the remapped
        # id back into ``self.data`` would cause a later fetch of the same
        # record to remap an id that has already been remapped whenever that
        # id is also a raw key of the mapping.
        segments_info = [dict(seg) for seg in data["segments_info"]]
        for seg in segments_info:
            category_id = seg['category_id']
            if category_id in self.coco_id_to_cont_id:
                seg['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(
                    f'category_id {category_id!r} is neither a key nor a value '
                    f'of coco_id_to_cont_id')
            # else: the id is already one of the mapped values; keep it.
        data_dict['segments_info'] = segments_info

        # Accept either a dict of processors or a single processor, matching
        # how the sibling dataset classes in this file select theirs.
        image_processor = self.data_args.image_processor
        if isinstance(image_processor, dict):
            processor = image_processor['panoptic']
        else:
            processor = image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)

        instruction = 'Panoptic Segmentation: You need to segment all objects '

        # One <cls> placeholder per class name, comma separated, ending with '.'.
        num_class = len(self.coco_class_name)
        category = '<cls>, ' * (num_class - 1) + '<cls>.'

        sources_value = f'This is all the candidate categories: {category}\n<image>\n'

        sources = [[{'from': 'human', 'value': sources_value + instruction},
                    {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]

        class_name_ids, cls_indices, random_idx = self.preprocess_class_name()
        data_dict['random_idx'] = random_idx
        # 1 at every position holding the <cls> placeholder index, 0 elsewhere.
        class_name_embedding_indices = torch.zeros_like(input_ids)
        class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'panoptic_coco'

        data_dict['class_name_ids'] = class_name_ids
        data_dict['cls_indices'] = cls_indices
        data_dict['class_name_embedding_indices'] = class_name_embedding_indices
        return data_dict
|
| |
|
class COCO_semantic_dataset(COCO_panoptic_dataset):
    """Semantic-segmentation view over the panoptic records."""

    def __getitem__(self, idx):
        """Assemble one semantic-segmentation sample for the record at ``idx``.

        Args:
            idx: Index into ``self.data``.

        Returns:
            The dict returned by ``processor.preprocess`` extended with
            ``input_ids``, ``labels``, ``class_name_ids``, ``cls_indices``
            and ``class_name_embedding_indices``.

        Raises:
            KeyError: If a segment's ``category_id`` is not a key of
                ``self.coco_id_to_cont_id``.
        """
        data = self.data[idx]
        image_id = int(data["image_id"])
        image_file = os.path.join(
            self.panoptic_image_path, os.path.splitext(data["file_name"])[0] + ".jpg")
        sem_label_file = os.path.join(self.semantic_gt_path, data["file_name"])

        data_dict = {
            'file_name': image_file,
            'image_id': image_id,
            # Both entries point at the semantic label file.
            'pan_seg_file_name': sem_label_file,
            'sem_seg_file_name': sem_label_file,
        }

        # Remap on per-call copies of the segment dicts. Writing the remapped
        # id back into ``self.data`` would make a later fetch of the same
        # record look up an already-remapped id, which either raises
        # KeyError or silently yields a different class.
        segments_info = [dict(seg) for seg in data["segments_info"]]
        for seg in segments_info:
            seg['category_id'] = self.coco_id_to_cont_id[seg['category_id']]
        data_dict['segments_info'] = segments_info

        # ``image_processor`` is either a dict of processors or one processor.
        image_processor = self.data_args.image_processor
        if isinstance(image_processor, dict):
            processor = image_processor['panoptic']
        else:
            processor = image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)

        prefix_inst = 'This is an image <image>, Please do Semantic Segmentation.'

        # One <cls> placeholder per class name, comma separated, ending with '.'.
        num_class = len(self.coco_class_name)
        category = '<cls>, ' * (num_class - 1) + '<cls>.'

        sources_value = f'\nThis is all the candidate categories: {category}\n'

        sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
                    {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]

        class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
        # 1 at every position holding the <cls> placeholder index, 0 elsewhere.
        class_name_embedding_indices = torch.zeros_like(input_ids)
        class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = text_dict['labels'][0]

        data_dict['class_name_ids'] = class_name_ids
        data_dict['cls_indices'] = cls_indices
        data_dict['class_name_embedding_indices'] = class_name_embedding_indices
        return data_dict
|
| |
|
| |
|
| |
|
class RefCOCO_dataset(COCO_instance_dataset):
    """Referring-segmentation dataset built on the instance dataset."""

    def preprocess_referring_instruction(self, instruction, REFER_token='[SEG]'):
        """Tokenise a referring instruction and append a marker token.

        Args:
            instruction: Instruction text to encode.
            REFER_token: Text whose first token id is appended at the end.

        Returns:
            1-D tensor of the instruction token ids followed by the marker id.
        """
        tokenized = self.tokenizer.encode(instruction, add_special_tokens=False)
        tokenized = tokenized + [self.tokenizer.encode(REFER_token, add_special_tokens=False)[0]]

        return torch.tensor(tokenized)

    def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
                                 seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
                                 region_token_index=REGION_TOKEN_INDEX, refer_token_index=REFER_TOKEN_INDEX,
                                 return_tensors=None):
        """Tokenise ``prompt``, replacing placeholder tags with fixed indices.

        The tags ``<image>``, ``<seg>``, ``<cls>``, ``<region>`` and
        ``<refer>`` are emitted as the corresponding ``*_token_index``
        argument; all text between tags is encoded with
        ``tokenizer.encode(..., add_special_tokens=False)``.

        Args:
            prompt: Text to tokenise.
            tokenizer: Object providing ``encode``.
            image_token_index, seg_token_index, cls_token_index,
            region_token_index, refer_token_index: Ids emitted for the tags.
            return_tensors: ``None`` for a Python list, ``'pt'`` for a tensor.

        Returns:
            List of ids, or a squeezed ``torch.long`` tensor for ``'pt'``.

        Raises:
            ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
        """
        special_token_map = {
            '<image>': image_token_index,
            '<seg>': seg_token_index,
            '<cls>': cls_token_index,
            '<region>': region_token_index,
            '<refer>': refer_token_index,
        }
        # The capturing group makes re.split keep the tags as separate chunks.
        prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>|<refer>)', prompt)

        input_ids = []
        for chunk in prompt_chunks:
            if chunk in special_token_map:
                input_ids.append(special_token_map[chunk])
            else:
                input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))

        if return_tensors is None:
            return input_ids
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long).squeeze()
        raise ValueError(f'Unsupported tensor type: {return_tensors}')

    def __getitem__(self, idx):
        """Assemble one referring-segmentation sample for the record at ``idx``.

        Args:
            idx: Index into ``self.data``.

        Returns:
            The dict returned by ``processor.preprocess`` extended with
            ``input_ids``, ``labels``, ``dataset_type``
            (``'referring_coco'``), ``token_refer_id`` and
            ``refer_embedding_indices``.

        Raises:
            ValueError: If an annotation's ``category_id`` is neither a key
                nor a value of ``self.coco_id_to_cont_id``.
        """
        data = self.data[idx]
        image_file = data['image_info']['file_name']
        image_folder = self.data_args.refcoco_image_folder

        # Remap on per-call copies of the annotation dicts. Writing the
        # remapped id back into ``self.data`` would cause a later fetch of
        # the same record to remap an id that has already been remapped
        # whenever that id is also a raw key of the mapping.
        annotations = [dict(annotation) for annotation in data['anns']]
        for annotation in annotations:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            category_id = annotation['category_id']
            if category_id in self.coco_id_to_cont_id:
                annotation['category_id'] = self.coco_id_to_cont_id[category_id]
            elif category_id not in self.coco_id_to_cont_id.values():
                raise ValueError(
                    f'category_id {category_id!r} is neither a key nor a value '
                    f'of coco_id_to_cont_id')
            # else: the id is already one of the mapped values; keep it.
            annotation['image_id'] = data['new_img_id']

        data_dict = {
            'file_name': os.path.join(image_folder, image_file),
            'height': data['image_info']['height'],
            'width': data['image_info']['width'],
            'image_id': data['new_img_id'],
            'annotations': annotations,
        }

        # ``image_processor`` is either a dict of processors or one processor.
        image_processor = self.data_args.image_processor
        if isinstance(image_processor, dict):
            processor = image_processor['instance']
        else:
            processor = image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)

        sentences = data['instruction']

        prefix_inst = 'This is an image <image>, Please doing Referring Segmentation according to the following instruction:'
        # Every sentence becomes " <sent>." and the pieces are concatenated.
        instruction = ''.join(' {}.'.format(sent['sent']) for sent in sentences)
        sources = [[{'from': 'human', 'value': prefix_inst + '\n<refer>'},
                    {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]

        token_refer_id = self.preprocess_referring_instruction(instruction)
        # 1 at every position holding the <refer> placeholder index, 0 elsewhere.
        refer_embedding_indices = torch.zeros_like(input_ids)
        refer_embedding_indices[input_ids == REFER_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'referring_coco'

        data_dict['token_refer_id'] = token_refer_id
        data_dict['refer_embedding_indices'] = refer_embedding_indices
        return data_dict
|
| |
|
def preprocess_multimodal(
    sources,
    data_args
):
    """Normalise image placeholders inside conversation turns, in place.

    When ``data_args.is_multimodal`` is falsy the input is returned
    untouched. Otherwise, for every turn that contains the image token, the
    token is removed from wherever it appears and re-inserted once at the
    front followed by a newline; for "mmtag" conversation versions it is
    additionally wrapped in ``<Image>...</Image>``. Finally, when
    ``data_args.mm_use_im_start_end`` is set, each image token is surrounded
    by the start/end tokens.

    Args:
        sources: List of conversations, each a list of dicts with a
            ``'value'`` string.
        data_args: Object exposing ``is_multimodal`` and
            ``mm_use_im_start_end``.

    Returns:
        The same ``sources`` object (its turns are modified in place).
    """
    if not data_args.is_multimodal:
        return sources

    for conversation in sources:
        for turn in conversation:
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                # Move the image token to the very beginning of the turn.
                stripped = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + stripped).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    wrapped = '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>'
                    text = text.replace(DEFAULT_IMAGE_TOKEN, wrapped)

            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            else:
                replace_token = DEFAULT_IMAGE_TOKEN
            turn["value"] = text.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
|
| |
|
class UnifyDatasetSingleDatasetForBatch(Dataset):
    """
    Dataset wrapper that serves several datasets one batch at a time.

    Purpose: every ``bs`` consecutive ``__getitem__`` calls are answered by
    the same underlying dataset; afterwards the wrapper moves on to the next
    dataset, wrapping around at the end of the list. The requested index is
    taken modulo the length of the dataset currently being served.
    On construction the largest ``coco_id_to_cont_id`` and
    ``coco_class_name`` found among the datasets are written onto the
    wrapper and onto every dataset.

    Arguments:
        datasets (sequence): List of datasets to be served
        dataset_ratio: stored on the instance; not read by this class
        bs (int): number of consecutive samples taken from one dataset
        fix_dataset_len (int): when non-zero, the value ``__len__`` reports
    """

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` over ``sequence``."""
        totals = []
        running = 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets, dataset_ratio, bs, fix_dataset_len=0):
        super(UnifyDatasetSingleDatasetForBatch, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.fix_dataset_len = fix_dataset_len

        # Number of samples served so far from the current dataset.
        self.cnt = 0
        self.bs = bs

        self.datasets = list(datasets)
        self.datasets_index_list = list(range(len(datasets)))
        self.dataset_ratio = dataset_ratio
        self.cur_dataset_index = 0
        self.dataset_length = [len(member) for member in self.datasets]
        self.cumulative_sizes = self.cumsum(self.datasets)

        # For each shared attribute: pick the first strictly-largest value
        # among the datasets (an empty dict when none is non-empty) and
        # broadcast it to the wrapper and to every dataset.
        for attr_name in ('coco_id_to_cont_id', 'coco_class_name'):
            largest = {}
            for member in self.datasets:
                candidate = getattr(member, attr_name, [])
                if len(candidate) > len(largest):
                    largest = candidate
            setattr(self, attr_name, largest)
            for member in self.datasets:
                setattr(member, attr_name, largest)

    def update_dataset_index(self):
        """Advance to the next dataset, wrapping around after the last one."""
        self.cur_dataset_index = (self.cur_dataset_index + 1) % len(self.datasets)

    def __len__(self):
        """Return ``fix_dataset_len`` when non-zero, else the summed dataset sizes."""
        if self.fix_dataset_len != 0:
            return self.fix_dataset_len
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return a sample from the current dataset and update the batch counter."""
        active = self.cur_dataset_index
        sample = self.datasets[active][idx % self.dataset_length[active]]

        self.cnt += 1
        if self.cnt == self.bs:
            # A full batch has been served; rotate to the next dataset.
            self.cnt = 0
            self.update_dataset_index()
        return sample
|
| |
|
| |
|
| |
|
class MM_Conv_Dataset(Dataset):
    """Multimodal conversation dataset loaded lazily from a JSON file.

    Each record is expected to carry a ``"conversations"`` list and may
    carry an ``"image"`` path.
    """

    def __init__(self, data_path,
                 tokenizer,
                 data_args):
        """Load the record list from ``data_path``.

        Args:
            data_path: Path of the JSON file holding the list of records.
            tokenizer: Tokenizer used by ``preprocess_llama2``.
            data_args: Object providing ``image_processor``, ``mmconv_path``,
                ``image_folder`` and the flags read by
                ``preprocess_multimodal``.
        """
        super(MM_Conv_Dataset, self).__init__()
        # Context manager so the file handle is closed as soon as the JSON
        # has been parsed instead of being left to the garbage collector.
        with open(data_path, "r") as json_file:
            list_data_dict = json.load(json_file)

        print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        """Return the number of records."""
        return len(self.list_data_dict)

    def preprocess_llama2(self, sources, tokenizer):
        """Render conversations with the LLAMA_2 template and build labels.

        Args:
            sources: List of conversations; each is a list of dicts with
                ``"from"`` (``"human"`` / ``"gpt"``) and ``"value"``.
            tokenizer: Tokenizer passed to ``tokenizer_special_tokens``; its
                ``pad_token_id`` and ``model_max_length`` are also read.

        Returns:
            ``dict(input_ids=..., labels=...)`` where ``labels`` is a copy of
            ``input_ids`` with the instruction parts set to ``IGNORE_INDEX``.
        """
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply the prompt template to every conversation.
        conversations = []
        for i, source in enumerate(sources):
            if roles[source[0]["from"]] != conv.roles[0]:
                # Drop the first turn when it is not from the human role.
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenise the rendered prompts.
        input_ids = torch.stack(
            [self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
            dim=0)

        targets = input_ids.clone()

        assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

        # Mask everything up to and including the instruction separator.
        sep = "[/INST] "
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep2)
            if conv.version == 'phi':
                # NOTE(review): the +1/+2/-1 offsets below presumably account
                # for tokens the tokenizer adds per call - confirm against
                # the tokenizer in use.
                cur_len = 0
                target[:cur_len] = IGNORE_INDEX
                round_idx = 0
                for rou in rounds:
                    if rou == "":
                        continue

                    parts = rou.split(sep)
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                    if round_idx > 0:
                        round_len = len(self.tokenizer_special_tokens(rou, tokenizer)) + 2
                        instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer))
                    else:
                        round_len = len(self.tokenizer_special_tokens(rou, tokenizer)) + 1
                        instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 1

                    target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

                    cur_len += round_len
                    round_idx += 1
                target[cur_len:] = IGNORE_INDEX
            else:
                # The first position is always ignored in this branch.
                cur_len = 1
                target[:cur_len] = IGNORE_INDEX
                for rou in rounds:
                    if rou == "":
                        continue

                    parts = rou.split(sep)
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                    round_len = len(self.tokenizer_special_tokens(rou, tokenizer))
                    instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 2

                    target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

                    cur_len += round_len
                target[cur_len:] = IGNORE_INDEX

            # When the accumulated length disagrees with the number of
            # non-pad tokens, the whole sample is excluded from the loss.
            if cur_len < tokenizer.model_max_length:
                if cur_len != total_len:
                    target[:] = IGNORE_INDEX
        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
                                 seg_token_index=SEG_TOKEN_INDEX, return_tensors=None):
        """Tokenise ``prompt`` with ``<image>`` / ``<seg>`` replaced by indices.

        The prompt is split on ``<image>`` and then on ``<seg>``; every text
        piece is tokenised with ``tokenizer(piece).input_ids`` and the
        matching placeholder index is inserted between consecutive pieces.

        Args:
            prompt: Text to tokenise.
            tokenizer: Callable returning an object with ``input_ids``.
            image_token_index: Id emitted for ``<image>``.
            seg_token_index: Id emitted for ``<seg>``.
            return_tensors: ``None`` for a Python list, ``'pt'`` for a tensor.

        Returns:
            List of ids, or a squeezed ``torch.long`` tensor for ``'pt'``.

        Raises:
            ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
        """
        prompt_chunks = []
        special_tokens = []

        for i, chunk in enumerate(prompt.split('<image>')):
            if i != 0:
                # Every split boundary after the first piece was an <image>.
                special_tokens.append('<image>')
            seg_splits = chunk.split('<seg>')
            prompt_chunks.extend(seg_splits)
            special_tokens.extend(['<seg>'] * (len(seg_splits) - 1))
        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks]
        special_indexes = [image_token_index if token == '<image>' else seg_token_index
                           for token in special_tokens]

        # Interleave text chunks with the placeholder ids between them.
        input_ids = []
        last_chunk = len(prompt_chunks) - 1
        for i, chunk in enumerate(prompt_chunks):
            input_ids.extend(chunk)
            if i != last_chunk:
                input_ids.append(special_indexes[i])

        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long).squeeze()
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the sample for record ``i``.

        Args:
            i: Index into ``self.list_data_dict``.

        Returns:
            Dict with ``input_ids``, ``labels``, ``dataset_type``
            (``'mm_conv'``) and ``image``, plus whatever the image processor
            returned when the record has an image. When no ``image`` entry
            is present after processing, a zero tensor of shape
            ``(3, 1024, 1024)`` is used.
        """
        sources = self.list_data_dict[i]
        sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"
        data_dict = {}
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.mmconv_path
            if isinstance(self.data_args.image_processor, dict):
                processor = self.data_args.image_processor['instance']
            else:
                processor = self.data_args.image_processor
            if 'coco' in image_file:
                # Paths mentioning "coco" are resolved by base name inside
                # ``image_folder`` instead of ``mmconv_path``.
                image_folder = self.data_args.image_folder
                image_file = os.path.basename(image_file)
            data_dict['file_name'] = os.path.join(image_folder, image_file)
            data_dict = processor.preprocess(data_dict)

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        data_dict['input_ids'] = text_dict['input_ids'][0]
        data_dict['labels'] = text_dict['labels'][0]
        data_dict['dataset_type'] = 'mm_conv'
        if 'image' not in data_dict:
            # Placeholder image so every sample carries an ``image`` entry.
            crop_size = 1024
            data_dict['image'] = torch.zeros(3, crop_size, crop_size)
        return data_dict
|
| |
|
| |
|
| |
|
| |
|
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` / ``labels`` to a common length, truncates them to
    ``tokenizer.model_max_length``, batches the optional image and
    class-name / referring tensors, and exposes what is left of every
    instance under ``seg_info``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_if_uniform(tensors):
        """Stack ``tensors`` when none is None and all shapes match, else return the list."""
        if all(x is not None and x.shape == tensors[0].shape for x in tensors):
            return torch.stack(tensors)
        return tensors

    @staticmethod
    def _stack_or_pad(tensors, padding_value):
        """Stack equally-shaped tensors; otherwise pad them with ``padding_value``."""
        if any(x.shape != tensors[0].shape for x in tensors):
            return torch.nn.utils.rnn.pad_sequence(
                tensors,
                batch_first=True,
                padding_value=padding_value,
            )
        return torch.stack(tensors, dim=0)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Turn a list of sample dicts into one batch dict.

        Args:
            instances: Sample dicts; each must contain ``input_ids`` and
                ``labels``. The keys ``input_ids``, ``labels`` and ``image``
                are removed from every instance as a side effect.

        Returns:
            Batch dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``seg_info`` and, when present on the first instance,
            ``images``, ``vp_images``, ``dataset_type``, ``class_name_ids``,
            ``token_refer_id``, ``cls_indices``, ``random_idx``,
            ``class_name_embedding_indices`` and ``refer_embedding_indices``.
        """
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Optional keys are detected on the first instance only.
        first = instances[0]
        if 'image' in first:
            batch['images'] = self._stack_if_uniform(
                [instance['image'] for instance in instances])

        if 'vp_image' in first:
            batch['vp_images'] = self._stack_if_uniform(
                [instance['vp_image'] for instance in instances])

        # Drop the tensors that were just batched. ``pop`` with a default is
        # used because ``image`` is treated as optional above; a plain
        # ``del`` would raise KeyError for instances without it.
        for instance in instances:
            for key in ('input_ids', 'labels', 'image'):
                instance.pop(key, None)
        batch['seg_info'] = list(instances)

        if 'dataset_type' in first:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        if 'class_name_ids' in first:
            batch['class_name_ids'] = self._stack_or_pad(
                [instance['class_name_ids'] for instance in instances], -1)
        if 'token_refer_id' in first:
            # Kept as a plain list of tensors.
            batch['token_refer_id'] = [instance['token_refer_id'] for instance in instances]
        if 'cls_indices' in first:
            batch['cls_indices'] = self._stack_or_pad(
                [instance['cls_indices'] for instance in instances], -1)
        if 'random_idx' in first:
            batch['random_idx'] = torch.stack(
                [instance['random_idx'] for instance in instances], dim=0)
        if 'class_name_embedding_indices' in first:
            batch['class_name_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['class_name_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)
        if 'refer_embedding_indices' in first:
            batch['refer_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['refer_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)

        return batch
|
| |
|
| |
|
| |
|
| |
|
@dataclass
class DataCollatorForCOCODatasetV2_old(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` / ``labels`` to a common length, truncates them to
    ``tokenizer.model_max_length``, batches the optional image and
    class-name / referring tensors, and exposes what is left of every
    instance under ``seg_info``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_or_pad(tensors, padding_value):
        """Stack equally-shaped tensors; otherwise pad them with ``padding_value``."""
        if any(x.shape != tensors[0].shape for x in tensors):
            return torch.nn.utils.rnn.pad_sequence(
                tensors,
                batch_first=True,
                padding_value=padding_value,
            )
        return torch.stack(tensors, dim=0)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Turn a list of sample dicts into one batch dict.

        Args:
            instances: Sample dicts; each must contain ``input_ids`` and
                ``labels``. The keys ``input_ids``, ``labels`` and ``image``
                are removed from every instance as a side effect.

        Returns:
            Batch dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``seg_info`` and, when present on the first instance,
            ``images``, ``dataset_type``, ``class_name_ids``,
            ``token_refer_id``, ``cls_indices``, ``random_idx``,
            ``class_name_embedding_indices`` and ``refer_embedding_indices``.
        """
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Optional keys are detected on the first instance only.
        first = instances[0]
        if 'image' in first:
            images = [instance['image'] for instance in instances]
            # Stack only when no image is None and all shapes agree.
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        # Drop the tensors that were just batched. ``pop`` with a default is
        # used because ``image`` is treated as optional above; a plain
        # ``del`` would raise KeyError for instances without it.
        for instance in instances:
            for key in ('input_ids', 'labels', 'image'):
                instance.pop(key, None)
        batch['seg_info'] = list(instances)

        if 'dataset_type' in first:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        if 'class_name_ids' in first:
            batch['class_name_ids'] = self._stack_or_pad(
                [instance['class_name_ids'] for instance in instances], -1)
        if 'token_refer_id' in first:
            # Kept as a plain list of tensors.
            batch['token_refer_id'] = [instance['token_refer_id'] for instance in instances]
        if 'cls_indices' in first:
            batch['cls_indices'] = self._stack_or_pad(
                [instance['cls_indices'] for instance in instances], -1)
        if 'random_idx' in first:
            batch['random_idx'] = torch.stack(
                [instance['random_idx'] for instance in instances], dim=0)
        if 'class_name_embedding_indices' in first:
            batch['class_name_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['class_name_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)
        if 'refer_embedding_indices' in first:
            batch['refer_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['refer_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)

        return batch
|
| |
|