| | import copy |
| | import random |
| | import glob |
| | import json |
| | import logging |
| | import os |
| | import torch |
| |
|
| | from mmengine import print_log |
| | from mmengine.config import Config, ConfigDict |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from pycocotools.coco import COCO |
| |
|
| | from xtuner.registry import BUILDER |
| |
|
| | from xtuner.dataset.utils import encode_fn |
| | from xtuner.dataset.map_fns import llava_map_fn |
| |
|
| | from projects.glamm.datasets.utils.utils import expand2square |
| |
|
| | from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST |
| | from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
| |
|
| |
|
class SemanticSegDataset(Dataset):
    """Semantic-segmentation dataset producing image, mask and conversation samples.

    Each sample pairs an image with a per-pixel label map. At access time a
    few classes present in the label map are sampled, one binary mask is
    built per sampled class, and a question/answer conversation asking to
    segment each class is generated and tokenized.

    Subclasses customise data discovery by overriding
    ``json_file_preprocess`` and label decoding by overriding
    ``decode_mask``.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        """Build the tokenizer and image processor(s) and index the samples.

        Args:
            image_folder: Directory holding the images.
            image_processor: Processor instance, or a dict/Config/ConfigDict
                that is built through ``BUILDER``.
            data_path: Path handed to ``json_file_preprocess``.
            tokenizer: Tokenizer config built through ``BUILDER``.
            offline_processed_text_folder: Not supported; any non-``None``
                value raises ``NotImplementedError``.
            max_dataset_length: Accepted but not used by this class.
            dataset_map_fn: Stored on the instance; not used by this class.
            template_map_fn: Callable, or (when ``lazy`` is true) a dict
                with a ``'type'`` callable plus its keyword arguments.
            max_length: Maximum token length forwarded to ``encode_fn``.
            pad_image_to_square: Pad images to a square before processing.
            num_proc: Accepted but not used by this class.
            lazy: When true, ``template_map_fn`` is applied per sample in
                ``__getitem__``.
            repeats: Multiplier applied to the reported dataset length.
            gcg_format: Not implemented; ``__getitem__`` raises
                ``NotImplementedError`` when it is true.
            num_classes_per_sample: Upper bound on classes sampled per image.
            extra_image_processor: Optional config of a second processor
                whose output is stored under ``'g_pixel_values'``.

        Raises:
            NotImplementedError: If ``offline_processed_text_folder`` is set.
        """
        super().__init__()
        self.gcg_format = gcg_format
        # The attribute is only created when a config is given; __getitem__
        # relies on hasattr() to detect it.
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        # Register the special tokens used by the conversation templates.
        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        assert offline_processed_text_folder or (data_path and tokenizer)
        self.lazy = lazy

        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            self.template_map_fn = self._instantiate_map_fn(
                self.template_map_fn)

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            self.image_label_datas = self.json_file_preprocess(
                data_path, image_folder)

        self.image_folder = image_folder

        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor

        # crop_size may be a dict, a single int, or a (width, height) pair.
        size = self.image_processor.crop_size
        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    @staticmethod
    def _instantiate_map_fn(map_fn_cfg):
        """Call ``map_fn_cfg['type']`` with the remaining entries as kwargs.

        The config is copied first so the caller's dict keeps its ``'type'``
        entry and can be used to build another dataset.
        """
        factory = map_fn_cfg['type']
        kwargs = dict(map_fn_cfg)
        kwargs.pop('type')
        return factory(**kwargs)

    def json_file_preprocess(self, data_path, image_folder):
        """Index ``.jpg`` images in ``image_folder`` and load class names.

        ``data_path`` is a JSON file whose content is stored (as a NumPy
        array) in ``self.classes``. Each label path is derived from the image
        path by replacing ``.jpg`` with ``.png`` and ``images`` with
        ``annotations``.

        Returns:
            A list of ``{"image": path, "label": path}`` dicts.
        """
        with open(data_path, 'r') as file:
            ade20k_classes = json.load(file)
        ade20k_images = [
            os.path.join(image_folder, img)
            for img in os.listdir(image_folder)
            if img.endswith('.jpg')
        ]
        self.classes = np.array(ade20k_classes)

        return [
            {
                "image": image,
                "label": image.replace(".jpg", ".png").replace(
                    "images", "annotations"),
            }
            for image in ade20k_images
        ]

    def __len__(self):
        """Return the number of indexed samples times ``repeats``."""
        return len(self.image_label_datas) * self.repeats

    @property
    def modality_length(self):
        """Return a constant length (100) per sample, repeated ``repeats`` times."""
        return [100] * len(self.image_label_datas) * self.repeats

    def real_len(self):
        """Return the number of indexed samples, ignoring ``repeats``."""
        return len(self.image_label_datas)

    def decode_mask(self, label_path):
        """Load a label map and sample binary masks for some of its classes.

        Label value 0 is mapped to 255 (ignored) and every other value is
        shifted down by one before the classes are sampled.

        Returns:
            ``(masks, selected_labels)`` where ``masks`` stacks one boolean
            mask per selected class along dim 0, or ``(None, None)`` when the
            label map contains no class other than 255.
        """
        label = np.array(Image.open(label_path))

        label = np.where(label == 0, 255, label - 1)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            return None, None

        num_selected = min(len(unique_labels), self.num_classes_per_sample)
        selected_labels = np.random.choice(
            unique_labels, num_selected, replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack(
            [label == class_id for class_id in selected_labels], dim=0)
        return masks, selected_labels

    def _fallback_item(self, index):
        """Return sample 0 in place of an unusable sample at ``index``.

        Raises:
            RuntimeError: If ``index`` is already 0, because falling back to
                sample 0 again would recurse without end.
        """
        if index == 0:
            raise RuntimeError(
                'Sample 0 has no valid label and cannot be used as the '
                'fallback sample.')
        return self.__getitem__(0)

    def __getitem__(self, index):
        """Return the processed sample at ``index`` (wrapped modulo the real length).

        The returned dict carries ``'pixel_values'``, ``'masks'``,
        ``'conversation'``, optionally ``'g_pixel_values'``, and whatever
        ``template_map_fn`` / ``encode_fn`` add.

        Raises:
            NotImplementedError: If ``gcg_format`` is true.
            RuntimeError: If the sample and the fallback sample 0 both have
                no valid label.
        """
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.image_label_datas[index])

        assert 'image' in data_dict.keys()
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image = Image.open(image_file).convert('RGB')
            if hasattr(self, 'extra_image_processor'):
                g_image = np.array(image)
                g_image = self.extra_image_processor.apply_image(g_image)
                # HWC -> CHW
                g_pixel_values = torch.from_numpy(
                    g_image).permute(2, 0, 1).contiguous()
                data_dict['g_pixel_values'] = g_pixel_values

            if self.pad_image_to_square:
                # Pad with the processor's mean colour scaled to 0-255.
                image = expand2square(
                    image,
                    tuple(int(x * 255)
                          for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

            data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
            if class_id is None:
                return self._fallback_item(index)

            if self.gcg_format:
                # No conversation is built for this format; fail explicitly
                # instead of reading an unbound variable below.
                raise NotImplementedError(
                    f'gcg_format=True is not implemented for '
                    f'{type(self).__name__}')

            conversation = []
            for i, c_id in enumerate(class_id):
                question = random.choice(SEG_QUESTIONS).format(
                    class_name=self.classes[c_id].lower())
                if i == 0:
                    # Only the first turn carries the image token.
                    question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
                conversation.append(
                    {'input': question, 'output': random.choice(ANSWER_LIST)})

            data_dict.update({'conversation': conversation})
        else:
            # Use the size normalised in __init__, so int and pair crop
            # sizes work here as well as dict ones.
            data_dict['pixel_values'] = torch.zeros(
                3, self.image_h, self.image_w)
            data_dict['masks'] = None

        if self.lazy:
            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer,
                               max_length=self.max_length,
                               with_image_token=True)
            data_dict.update(result)

        return data_dict
| |
|
class ADE20kSemanticSegDataset(SemanticSegDataset):
    """ADE20K variant of ``SemanticSegDataset``.

    Adds no behaviour of its own: every constructor argument is forwarded
    unchanged to the base class.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        """Forward all arguments to ``SemanticSegDataset.__init__``."""
        base_kwargs = dict(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
            gcg_format=gcg_format,
            num_classes_per_sample=num_classes_per_sample,
            extra_image_processor=extra_image_processor,
        )
        super().__init__(**base_kwargs)
| |
|
class COCOStuffSemanticSegDataset(SemanticSegDataset):
    """COCO-Stuff variant of ``SemanticSegDataset``.

    Label maps are ``.png`` files found in ``label_path``; the matching image
    path is derived from each label path. Classes whose name contains ``-``
    are treated as ignored (mapped to 255) when masks are decoded.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 label_path=None,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        """Store ``label_path``, run the base constructor, index class names."""
        # Must be assigned first: json_file_preprocess reads it, and the
        # base constructor calls json_file_preprocess.
        self.label_path = label_path
        base_kwargs = dict(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
            gcg_format=gcg_format,
            num_classes_per_sample=num_classes_per_sample,
            extra_image_processor=extra_image_processor,
        )
        super().__init__(**base_kwargs)
        # Class name -> position in self.classes.
        self.cocostuff_class2index = dict(
            zip(self.classes, range(len(self.classes))))

    def json_file_preprocess(self, data_path, image_folder):
        """Read class names from ``data_path`` and pair label files with images.

        The first line of ``data_path`` is skipped; for every other line the
        text after the last ``": "`` is taken as a class name.

        Returns:
            A list of ``{"image": path, "label": path}`` dicts, one per
            ``.png`` file in ``self.label_path``.
        """
        assert self.label_path is not None
        with open(data_path, 'r') as file:
            raw_lines = file.readlines()

        class_names = []
        for line in raw_lines[1:]:
            class_names.append(line.strip().split(": ")[-1])

        label_dir = self.label_path
        label_files = glob.glob(os.path.join(label_dir, "*.png"))

        samples = []
        for label_file in label_files:
            image_file = label_file.replace(".png", ".jpg")
            image_file = image_file.replace(label_dir, image_folder)
            samples.append({"image": image_file, "label": label_file})

        self.classes = np.array(class_names)
        return samples

    def decode_mask(self, label_path):
        """Load a label map and sample binary masks for some of its classes.

        Returns:
            ``(masks, selected_labels)`` with one boolean mask per selected
            class stacked along dim 0, or ``(None, None)`` when no class
            other than 255 remains after ignored classes are removed.
        """
        label = np.array(Image.open(label_path))

        # Indices of classes whose name contains "-" become 255.
        ignored_ids = []
        for class_name, class_index in self.cocostuff_class2index.items():
            if "-" in class_name:
                ignored_ids.append(class_index)
        label = np.where(np.isin(label, ignored_ids), 255, label)

        candidates = [value for value in np.unique(label) if value != 255]
        if not candidates:
            print("No valid label !!!")
            return None, None

        sample_count = min(len(candidates), self.num_classes_per_sample)
        selected_labels = np.random.choice(
            candidates, sample_count, replace=False)

        label = torch.from_numpy(label).long()
        per_class_masks = [label == class_id for class_id in selected_labels]
        masks = torch.stack(per_class_masks, dim=0)
        return masks, selected_labels
| |
|
class PascalPartSemanticSegDataset(SemanticSegDataset):
    """Part-level segmentation dataset backed by a COCO-format annotation file.

    Samples are COCO image ids. Category names have the form
    ``"object:part"`` and are stored as ``(object, part)`` tuples.
    """

    def json_file_preprocess(self, data_path, image_folder):
        """Load the COCO annotation file and build the category-name map.

        Sets ``self.coco_api`` and ``self.classes`` (category id ->
        ``(object, part)`` tuple).

        Returns:
            The list of image ids reported by the COCO API.
        """
        self.coco_api = COCO(data_path)
        img_ids = self.coco_api.getImgIds()
        all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
        class_map_pascal_part = {}
        for cat in all_classes:
            # Unpacking requires exactly one ':' in the category name.
            cat_main, cat_part = cat["name"].strip().split(":")
            class_map_pascal_part[cat["id"]] = (cat_main, cat_part)
        self.classes = class_map_pascal_part
        return img_ids

    def __getitem__(self, index):
        """Return the processed sample at ``index`` (wrapped modulo the real length).

        Up to ``num_classes_per_sample`` annotations are sampled; each yields
        one mask and one question/answer turn. An image without annotations
        is replaced by sample 0.

        Raises:
            RuntimeError: If sample 0 itself has no annotations, since
                falling back to it again would recurse without end.
        """
        index = index % self.real_len()
        img_id = self.image_label_datas[index]
        img_info = self.coco_api.loadImgs([img_id])[0]

        # Look up annotations before decoding the image, so an image that
        # will be discarded is never loaded.
        annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
        annotations = self.coco_api.loadAnns(annotation_ids)
        if not annotations:
            if index == 0:
                raise RuntimeError(
                    'Sample 0 has no annotations and cannot be used as the '
                    'fallback sample.')
            return self.__getitem__(0)

        file_name = img_info["file_name"]
        data_dict = {}

        image_file = os.path.join(self.image_folder, file_name)
        image = Image.open(image_file).convert('RGB')

        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)
            g_image = self.extra_image_processor.apply_image(g_image)
            # HWC -> CHW
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        if self.pad_image_to_square:
            # Pad with the processor's mean colour scaled to 0-255.
            image = expand2square(
                image,
                tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        data_dict['pixel_values'] = image

        sampled_anns = np.random.choice(
            annotations,
            min(len(annotations), self.num_classes_per_sample),
            replace=False)

        conversation = []
        for i, ann in enumerate(sampled_anns):
            cat_id = ann['category_id']
            sampled_cls = self.classes[cat_id]
            if isinstance(sampled_cls, tuple):
                obj, part = sampled_cls
                # Randomly choose between two phrasings of the part name.
                name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
            else:
                name = sampled_cls
            question = random.choice(SEG_QUESTIONS).format(class_name=name)
            if i == 0:
                # Only the first turn carries the image token.
                question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
            conversation.append(
                {'input': question, 'output': random.choice(ANSWER_LIST)})

        masks = [self.coco_api.annToMask(ann) for ann in sampled_anns]
        masks = np.stack(masks, axis=0)
        masks = torch.from_numpy(masks)

        data_dict['masks'] = masks
        data_dict['conversation'] = conversation

        if self.lazy:
            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer,
                               max_length=self.max_length,
                               with_image_token=True)
            data_dict.update(result)

        return data_dict
| |
|
class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
    """PACO variant: category names may be ``"object"`` or ``"object:part"``."""

    def json_file_preprocess(self, data_path, image_folder):
        """Load the COCO annotation file and build the category-name map.

        Everything from the first ``"_("`` onwards is dropped from each name
        component. Plain object categories map to a string; object/part
        categories map to an ``(object, part)`` tuple.

        Returns:
            The list of image ids reported by the COCO API.
        """
        self.coco_api = COCO(data_path)

        def _base_name(token):
            # Keep only the text before the first "_(".
            return token.split("_(")[0]

        id_to_name = {}
        categories = self.coco_api.loadCats(self.coco_api.getCatIds())
        for category in categories:
            pieces = category["name"].strip().split(":")
            if len(pieces) != 1:
                assert len(pieces) == 2
                obj_token, part_token = pieces
                resolved = (_base_name(obj_token), _base_name(part_token))
            else:
                resolved = _base_name(pieces[0])
            id_to_name[category["id"]] = resolved

        self.classes = id_to_name
        return self.coco_api.getImgIds()