# projects/omg_llava/dataset/SemanticSegDataset.py
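"""Semantic segmentation datasets for OMG-LLaVA.

Each dataset pairs an image with a binary mask for a sampled category and a
conversation about that category. Variants cover ADE20K, COCO-Stuff,
Mapillary Vistas, Pascal-Part and PACO.
"""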
import random
import glob
import json
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from xtuner.registry import BUILDER
from .utils import expand2square, expand2square_mask
from .process_functions.semantic_seg_process import semantic_seg_conversations, semantic_seg_gcg_format_conversations
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import copy
from xtuner.dataset.utils import encode_fn


class SemanticSegDataset(Dataset):
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1,
gcg_format=False):
super().__init__()
self.tokenizer = tokenizer
assert offline_processed_text_folder or (data_path and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.dataset_map_fn = dataset_map_fn
self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # In lazy mode the map fn arrives as a config dict; build it now.
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set, loading dataset from `offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
self.image_label_datas = self.json_file_preprocess(data_path, image_folder)
if gcg_format:
conversations_datas = semantic_seg_gcg_format_conversations(self.classes)
else:
conversations_datas = semantic_seg_conversations(self.classes)
json_data = DatasetDict({'train': HFDataset.from_list(conversations_datas)})
if self.lazy:
self.text_data = build_origin_dataset(json_data, 'train')
else:
self.text_data = process_hf_dataset(
dataset=json_data,
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
split='train',
max_dataset_length=max_dataset_length,
remove_unused_columns=False,
pack_to_max_length=False,
with_image_token=True,
                    map_num_proc=num_proc,  # keep worker count modest to limit memory use
)
self.clsid2convs = self.construct_cls2convs_dict()
self.image_folder = image_folder
        size = image_processor.crop_size
        if isinstance(size, dict):
            # CLIP-style processors store {'height': ..., 'width': ...}.
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def construct_cls2convs_dict(self):
        """Group conversation items by class id for fast lookup."""
        ret = {}
        for conv_item in self.text_data:
            cls_id = conv_item['class_id']
            ret.setdefault(cls_id, []).append(conv_item)
        return ret

    def json_file_preprocess(self, data_path, image_folder):
        # ADE20K: `data_path` is a JSON list of class names; label maps live
        # under `annotations/` with the same basename as each image.
        with open(data_path, 'r') as file:
            ade20k_classes = json.load(file)
        ade20k_images = [
            os.path.join(image_folder, img)
            for img in os.listdir(image_folder) if img.endswith('.jpg')
        ]
        ade20k_labels = [
            img.replace('.jpg', '.png').replace('images', 'annotations')
            for img in ade20k_images
        ]
        self.classes = np.array(ade20k_classes)
        return [{'image': image, 'label': label}
                for image, label in zip(ade20k_images, ade20k_labels)]

    def __len__(self):
        return len(self.image_label_datas) * self.repeats

    @property
    def modality_length(self):
        # Every sample is image+text; use a constant pseudo-length so
        # length-grouped samplers treat all items alike.
        length_list = [100] * len(self.image_label_datas)
        return length_list * self.repeats

    def real_len(self):
        return len(self.image_label_datas)

    def decode_mask(self, label_path):
        """Return a (1, H, W) float mask for one sampled class and its id."""
        label = np.array(Image.open(label_path))
        # ADE20K labels are 1-indexed with 0 = unlabeled; shift to 0-indexed
        # and send unlabeled pixels to the ignore value 255.
        label = np.where(label == 0, 255, label - 1)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            return None, None
        # Sample a single class per image.
        selected_labels = np.random.choice(unique_labels, 1, replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack(
            [label == class_id for class_id in selected_labels], dim=0)
        masks = masks.numpy()
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        masks = torch.from_numpy(masks).to(torch.float32)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks, selected_labels[0]

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.image_label_datas[index])
        assert 'image' in data_dict
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
image = Image.open(image_file).convert('RGB')
ori_width, ori_height = image.size
if self.pad_image_to_square:
image = expand2square(
image,
tuple(
int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
data_dict['pixel_values'] = image
            # Build the binary mask for one sampled class.
            data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
            if class_id is None:
                # No valid class in this label map; fall back to sample 0.
                return self.__getitem__(0)
            # Pair the mask with a conversation about the same class.
            conv_datas = self.clsid2convs[class_id]
            selected_idx = np.random.randint(0, len(conv_datas))
            data_dict.update(conv_datas[selected_idx])
else:
if hasattr(self.image_processor, 'crop_size'):
crop_size = self.image_processor.crop_size
else:
crop_size = self.image_processor.size
data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
crop_size['width'])
data_dict['masks'] = None
if self.lazy:
result = self.dataset_map_fn(data_dict)
data_dict.update(result)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
data_dict.update(result)
return data_dict
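
# ----------------------------------------------------------------------------
# Usage sketch (illustration only; the names and paths below are placeholders,
# not part of this module). A minimal way to build the ADE20K variant,
# assuming a CLIP image processor and xtuner-style map functions:
#
#   from transformers import AutoTokenizer, CLIPImageProcessor
#   processor = CLIPImageProcessor.from_pretrained(
#       'openai/clip-vit-large-patch14-336')
#   tokenizer = AutoTokenizer.from_pretrained(
#       'internlm/internlm2-chat-7b', trust_remote_code=True)
#   dataset = ADE20kSemanticSegDataset(
#       image_folder='data/ade20k/images/training',   # placeholder path
#       data_path='data/ade20k/ade20k_classes.json',  # placeholder path
#       image_processor=processor,
#       tokenizer=tokenizer,
#       dataset_map_fn=...,    # e.g. an xtuner dataset map fn
#       template_map_fn=...,   # e.g. built via template_map_fn_factory
#       pad_image_to_square=True)
#   item = dataset[0]  # dict with 'pixel_values', 'masks', conversation fields
# ----------------------------------------------------------------------------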


class ADE20kSemanticSegDataset(SemanticSegDataset):
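    """ADE20K semantic segmentation in conversation format.

    A named alias of `SemanticSegDataset`, whose default preprocessing
    already targets ADE20K.
    """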
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1,
gcg_format=False):
super().__init__(
image_folder=image_folder,
image_processor=image_processor,
data_path=data_path,
tokenizer=tokenizer,
offline_processed_text_folder=offline_processed_text_folder,
max_dataset_length=max_dataset_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
max_length=max_length,
pad_image_to_square=pad_image_to_square,
num_proc=num_proc,
lazy=lazy,
repeats=repeats,
gcg_format=gcg_format,
)


class COCOStuffSemanticSegDataset(SemanticSegDataset):
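    """COCO-Stuff semantic segmentation in conversation format.

    Class names are parsed from a labels file at `data_path`; label maps are
    PNGs under `label_path` that share basenames with the images.
    """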
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1,
label_path=None,
gcg_format=False,):
self.label_path = label_path
super().__init__(
image_folder=image_folder,
image_processor=image_processor,
data_path=data_path,
tokenizer=tokenizer,
offline_processed_text_folder=offline_processed_text_folder,
max_dataset_length=max_dataset_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
max_length=max_length,
pad_image_to_square=pad_image_to_square,
num_proc=num_proc,
lazy=lazy,
repeats=repeats,
gcg_format=gcg_format,
)
self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)}

    def json_file_preprocess(self, data_path, image_folder):
        # COCO-Stuff: `data_path` lists classes as "<index>: <name>" lines
        # (the first line is skipped); label maps live under `self.label_path`.
        assert self.label_path is not None
        with open(data_path, 'r') as file:
            cocostuff_classes = [
                line.strip().split(': ')[-1] for line in file.readlines()[1:]
            ]
        coco_stuff_labels = glob.glob(os.path.join(self.label_path, '*.png'))
        # Images share basenames with their label maps.
        coco_stuff_images = [
            label.replace('.png', '.jpg').replace(self.label_path, image_folder)
            for label in coco_stuff_labels
        ]
        self.classes = np.array(cocostuff_classes)
        return [{'image': image, 'label': label}
                for image, label in zip(coco_stuff_images, coco_stuff_labels)]

    def decode_mask(self, label_path):
        """Return a (1, H, W) float mask for one sampled class and its id."""
        label = np.array(Image.open(label_path))
        # Skip aggregate COCO-Stuff categories, whose names contain '-'.
        ignored_classes = [
            index for class_name, index in self.cocostuff_class2index.items()
            if '-' in class_name
        ]
        label = np.where(np.isin(label, ignored_classes), 255, label)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            print(f'No valid label in {label_path}')
            return None, None
        # Sample a single class per image.
        selected_labels = np.random.choice(unique_labels, 1, replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack(
            [label == class_id for class_id in selected_labels], dim=0)
        masks = masks.numpy()
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        masks = torch.from_numpy(masks).to(torch.float32)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks, selected_labels[0]


class MapillarySemanticSegDataset(SemanticSegDataset):
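    """Mapillary Vistas semantic segmentation in conversation format.

    Class names come from the dataset's config JSON (`readable` field); label
    maps are PNGs under `label_path` that share basenames with the images.
    """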
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1,
label_path=None,
gcg_format=False,):
self.label_path = label_path
super().__init__(
image_folder=image_folder,
image_processor=image_processor,
data_path=data_path,
tokenizer=tokenizer,
offline_processed_text_folder=offline_processed_text_folder,
max_dataset_length=max_dataset_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
max_length=max_length,
pad_image_to_square=pad_image_to_square,
num_proc=num_proc,
lazy=lazy,
repeats=repeats,
gcg_format=gcg_format,
)

    def json_file_preprocess(self, data_path, image_folder):
        assert self.label_path is not None
        # Mapillary: `data_path` is the dataset config JSON; use the
        # human-readable class names, lower-cased.
        with open(data_path, 'r') as file:
            mapillary_classes = json.load(file)['labels']
        mapillary_classes = [cls['readable'].lower() for cls in mapillary_classes]
        mapillary_labels = sorted(
            glob.glob(os.path.join(self.label_path, '*.png')))
        mapillary_images = [
            label.replace('.png', '.jpg').replace(self.label_path, image_folder)
            for label in mapillary_labels
        ]
        self.classes = np.array(mapillary_classes)
        return [{'image': image, 'label': label}
                for image, label in zip(mapillary_images, mapillary_labels)]

    def decode_mask(self, label_path):
        """Return a (1, H, W) float mask for one sampled class and its id."""
        label = np.array(Image.open(label_path))
        # Skip compound ('-' or '(') and unlabeled categories.
        ignored_classes = [
            index for index, class_name in enumerate(self.classes)
            if '-' in class_name or '(' in class_name
            or 'unlabeled' in class_name
        ]
        label = np.where(np.isin(label, ignored_classes), 255, label)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            print(f'No valid label in {label_path}')
            return None, None
        # Sample a single class per image.
        selected_labels = np.random.choice(unique_labels, 1, replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack(
            [label == class_id for class_id in selected_labels], dim=0)
        masks = masks.numpy()
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        masks = torch.from_numpy(masks).to(torch.float32)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks, selected_labels[0]


class PascalPartSemanticSegDataset(Dataset):
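    """Pascal-Part part-level segmentation in conversation format.

    Annotations are COCO-style "object:part" categories; up to five parts are
    sampled per image and phrased as either "<obj> <part>" or
    "the <part> of the <obj>".
    """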
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1):
super().__init__()
self.tokenizer = tokenizer
assert offline_processed_text_folder or (data_path and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.dataset_map_fn = dataset_map_fn
self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # In lazy mode the map fn arrives as a config dict; build it now.
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set, loading dataset from `offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
json_datas = self.json_file_preprocess(data_path)
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
if self.lazy:
self.text_data = build_origin_dataset(json_data, 'train')
else:
self.text_data = process_hf_dataset(
dataset=json_data,
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
split='train',
max_dataset_length=max_dataset_length,
remove_unused_columns=False,
pack_to_max_length=False,
with_image_token=True,
                    map_num_proc=num_proc,  # keep worker count modest to limit memory use
)
self.image_folder = image_folder
        size = image_processor.crop_size
        if isinstance(size, dict):
            # CLIP-style processors store {'height': ..., 'width': ...}.
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def json_file_preprocess(self, data_path):
        pascal_part_api = COCO(data_path)
        all_classes = pascal_part_api.loadCats(pascal_part_api.getCatIds())
        # Category names are "object:part"; keep them as (object, part) pairs.
        class_map_pascal_part = {}
        for cat in all_classes:
            cat_main, cat_part = cat['name'].strip().split(':')
            class_map_pascal_part[cat['id']] = (cat_main, cat_part)
        img_ids = pascal_part_api.getImgIds()
        self.classes = class_map_pascal_part
        self.coco_api = pascal_part_api
        img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids]
        valid_img_infos = []
        for img_info in img_infos:
            annotation_ids = self.coco_api.getAnnIds(imgIds=img_info['id'])
            annotations = self.coco_api.loadAnns(annotation_ids)
            if not annotations:
                continue
            # Sample at most 5 annotations per image.
            sampled_anns = np.random.choice(
                annotations, 5, replace=False
            ) if len(annotations) >= 5 else annotations
            selected_labels = []
            for ann in sampled_anns:
                sampled_cls = self.classes[ann['category_id']]
                if isinstance(sampled_cls, tuple):
                    obj, part = sampled_cls
                    # Randomly phrase the part either way.
                    name = (f'{obj} {part}' if random.random() < 0.5
                            else f'the {part} of the {obj}')
                else:
                    name = sampled_cls
                selected_labels.append(name)
            img_info.update({'annotations': sampled_anns,
                             'selected_labels': selected_labels})
            valid_img_infos.append(img_info)
        return valid_img_infos

    def __len__(self):
        return len(self.text_data) * self.repeats

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            # Lazy mode has no tokenized length yet; use a constant.
            cur_len = 100 if self.lazy else len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                # Negative lengths mark text-only samples for the sampler.
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def real_len(self):
        return len(self.text_data)

    def decode_mask(self, annotations):
        try:
            # pycocotools turns polygon/RLE annotations into HxW uint8 masks.
            masks = [self.coco_api.annToMask(ann) for ann in annotations]
        except Exception as e:
            print(f'Error generating mask: {e}')
            return None
        masks = np.stack(masks, axis=0)
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        # Cast to float: F.interpolate expects a floating-point input.
        masks = torch.from_numpy(masks).to(torch.float32)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])
        assert 'image' in data_dict
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
image_file = os.path.join(self.image_folder, image_file)
image = Image.open(image_file).convert('RGB')
ori_width, ori_height = image.size
if self.pad_image_to_square:
image = expand2square(
image,
tuple(
int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
data_dict['pixel_values'] = image
            # Decode one mask per sampled part annotation.
            data_dict['masks'] = self.decode_mask(data_dict['annotations'])
            if data_dict['masks'] is None:
                # Mask decoding failed; fall back to sample 0.
                return self.__getitem__(0)
else:
if hasattr(self.image_processor, 'crop_size'):
crop_size = self.image_processor.crop_size
else:
crop_size = self.image_processor.size
data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
crop_size['width'])
data_dict['masks'] = None
if self.lazy:
result = self.dataset_map_fn(data_dict)
data_dict.update(result)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
data_dict.update(result)
return data_dict


class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
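    """PACO part segmentation in conversation format.

    Like Pascal-Part, but category names may carry "_(...)" suffixes that are
    stripped, and the full annotation records are kept in `self.json_datas`
    while only file names and selected labels go through the HF pipeline.
    """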
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
offline_processed_text_folder=None,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
num_proc=8,
lazy=False,
repeats=1,):
self.tokenizer = tokenizer
assert offline_processed_text_folder or (data_path and tokenizer)
self.lazy = lazy
self.max_length = max_length
self.dataset_map_fn = dataset_map_fn
self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # In lazy mode the map fn arrives as a config dict; build it now.
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set, loading dataset from `offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
json_datas = self.json_file_preprocess(data_path)
self.json_datas = json_datas
json_datas = self.only_get_hf_map_infos()
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
if self.lazy:
self.text_data = build_origin_dataset(json_data, 'train')
else:
self.text_data = process_hf_dataset(
dataset=json_data,
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=dataset_map_fn,
template_map_fn=template_map_fn,
split='train',
max_dataset_length=max_dataset_length,
remove_unused_columns=False,
pack_to_max_length=False,
with_image_token=True,
                    map_num_proc=num_proc,  # keep worker count modest to limit memory use
)
self.image_folder = image_folder
        size = image_processor.crop_size
        if isinstance(size, dict):
            # CLIP-style processors store {'height': ..., 'width': ...}.
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def only_get_hf_map_infos(self):
        # Expose only the fields the HF map functions need; the heavy
        # annotation records stay behind in `self.json_datas`.
        return [{'file_name': json_data['file_name'],
                 'selected_labels': json_data['selected_labels']}
                for json_data in self.json_datas]

    def json_file_preprocess(self, data_path):
        paco_api = COCO(data_path)
        all_classes = paco_api.loadCats(paco_api.getCatIds())
        # Category names are "object:part" with optional "_(...)" suffixes;
        # strip the suffixes and keep (object, part) pairs.
        class_map_paco = {}
        for cat in all_classes:
            cat_split = cat['name'].strip().split(':')
            if len(cat_split) == 1:
                name = cat_split[0].split('_(')[0]
            else:
                assert len(cat_split) == 2
                obj, part = cat_split
                name = (obj.split('_(')[0], part.split('_(')[0])
            class_map_paco[cat['id']] = name
        img_ids = paco_api.getImgIds()
        self.classes = class_map_paco
        self.coco_api = paco_api
        img_infos = [self.coco_api.loadImgs([img_id])[0] for img_id in img_ids]
        valid_img_infos = []
        for img_info in img_infos:
            annotation_ids = self.coco_api.getAnnIds(imgIds=img_info['id'])
            annotations = self.coco_api.loadAnns(annotation_ids)
            if not annotations:
                continue
            # Sample at most 5 annotations per image.
            sampled_anns = np.random.choice(
                annotations, 5, replace=False
            ) if len(annotations) >= 5 else annotations
            selected_labels = []
            for ann in sampled_anns:
                sampled_cls = self.classes[ann['category_id']]
                if isinstance(sampled_cls, tuple):
                    obj, part = sampled_cls
                    # Randomly phrase the part either way.
                    name = (f'{obj} {part}' if random.random() < 0.5
                            else f'the {part} of the {obj}')
                else:
                    name = sampled_cls
                selected_labels.append(name)
            img_info.update({'annotations': sampled_anns,
                             'selected_labels': selected_labels})
            valid_img_infos.append(img_info)
        return valid_img_infos

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])
        # Re-attach the full annotation record kept outside the HF dataset.
        data_dict.update(self.json_datas[index])
        assert 'image' in data_dict
if data_dict.get('image', None) is not None:
image_file = data_dict['image']
image_file = os.path.join(self.image_folder, image_file)
image = Image.open(image_file).convert('RGB')
ori_width, ori_height = image.size
if self.pad_image_to_square:
image = expand2square(
image,
tuple(
int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
data_dict['pixel_values'] = image
            # Decode one mask per sampled part annotation.
            data_dict['masks'] = self.decode_mask(data_dict['annotations'])
            if data_dict['masks'] is None:
                # Mask decoding failed; fall back to sample 0.
                return self.__getitem__(0)
else:
if hasattr(self.image_processor, 'crop_size'):
crop_size = self.image_processor.crop_size
else:
crop_size = self.image_processor.size
data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
crop_size['width'])
data_dict['masks'] = None
if self.lazy:
result = self.dataset_map_fn(data_dict)
data_dict.update(result)
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
data_dict.update(result)
return data_dict
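

if __name__ == '__main__':
    # Illustration only: a tiny smoke test of the ADE20K mask path that needs
    # no tokenizer, processor or data on disk. It mimics
    # SemanticSegDataset.decode_mask on a synthetic 4x4 label map.
    fake_label = np.array([[0, 1, 1, 2]] * 4, dtype=np.uint8)
    shifted = np.where(fake_label == 0, 255, fake_label - 1)  # 0 -> ignore
    valid_ids = [i for i in np.unique(shifted) if i != 255]
    mask = torch.from_numpy((shifted == valid_ids[0]).astype(np.float32))
    mask = F.interpolate(mask[None, None], size=(8, 8), mode='nearest')
    print('valid class ids:', valid_ids, '| mask shape:', tuple(mask.shape))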