# PixDLM / utils / dataset.py
# Uploaded by WhynotHug using huggingface_hub (commit 3334467, verified), 29.3 kB.
import glob
import os
from queue import Empty
import random
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask
from transformers import CLIPImageProcessor
import transformers
from .muse import CustomSegDataset
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide, ResizeShortestSide
from .data_processing import get_mask_from_json
from .reason_seg_dataset import ReasonSegDataset
from .refer import REFER
from .refer_seg_dataset import ReferSegDataset
from .sem_seg_dataset import SemSegDataset
from .vqa_dataset import VQADataset
from .multi_reason_seg_dataset import MultiReasonSegDataset
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
from model.llava import conversation as conversation_lib
from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN)
def collate_fn(
    batch, tokenizer=None, conv_type="llava_v1", use_mm_start_end=True, local_rank=-1
):
    """Collate dataset samples into one batch and build language-model labels.

    Every element of ``batch`` is a tuple whose first ten fields are
    ``(image_path, images, images_clip, conversations, masks, label, resize,
    clip_resize, questions, sampled_classes)``.  Optional trailing fields are
    read as ``multi_reason`` followed by either
    ``inference[, category[, answers]]`` (when the second extra field is a
    bool) or ``category[, answers]`` (otherwise).

    Args:
        batch: Sequence of sample tuples as described above.
        tokenizer: Tokenizer used to encode the prompts.  Although it defaults
            to ``None`` it is required: ``pad_token_id`` and
            ``model_max_length`` are read from it.
        conv_type: Conversation template name.  ``"chatml"`` selects the
            ``chatml`` template, anything else uses the default conversation.
            ``"llava_v1"`` and ``"chatml"`` use ``sep + roles[1] + ": "`` as
            the instruction/answer separator; other values use ``"[/INST] "``.
        use_mm_start_end: When true, every image token in the prompts is
            wrapped in the image start/end tokens.
        local_rank: Unused; kept so existing callers keep working.

    Returns:
        dict with the stacked images, padded ``input_ids``, ``labels`` (prompt
        and padding positions set to ``IGNORE_INDEX``), ``attention_masks``,
        the per-sample lists, and ``offset`` (cumulative conversation counts
        per sample).  ``inference`` is taken from the first sample only.
    """
    image_path_list = []
    images_list = []
    images_clip_list = []
    conversation_list = []
    masks_list = []
    label_list = []
    resize_list = []
    questions_list = []
    sampled_classes_list = []
    clip_resize_list = []
    offset_list = [0]
    cnt = 0
    inferences = []
    multi_reasons = []
    categories = []
    answers_list = []

    for item in batch:
        (image_path, images, images_clip, conversations, masks, label,
         resize, clip_resize, questions, sampled_classes, *rest) = item

        # Decode the optional trailing fields; their layout depends on whether
        # the second extra field is a boolean inference flag.
        multi_reason = rest[0] if len(rest) >= 1 else False
        inference = False
        category = 'unknown'
        answers = None
        if len(rest) >= 2:
            if isinstance(rest[1], (bool, np.bool_)):
                inference = rest[1]
                if len(rest) >= 3:
                    category = rest[2] if isinstance(rest[2], str) else 'unknown'
                if len(rest) >= 4:
                    answers = rest[3]
            else:
                category = rest[1] if isinstance(rest[1], str) else 'unknown'
                if len(rest) >= 3:
                    answers = rest[2]

        image_path_list.append(image_path)
        images_list.append(images)
        images_clip_list.append(images_clip)
        conversation_list.extend(conversations)
        label_list.append(label)
        masks_list.append(masks.float())
        resize_list.append(resize)
        clip_resize_list.append(clip_resize)
        questions_list.append(questions)
        sampled_classes_list.append(sampled_classes)
        cnt += len(conversations)
        offset_list.append(cnt)
        inferences.append(inference)
        multi_reasons.append(multi_reason)
        categories.append(category)
        answers_list.append(answers)

    if use_mm_start_end:
        # The wrapped token is the same for every prompt, so build it once.
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        conversation_list = [
            prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
            for prompt in conversation_list
        ]

    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        for prompt in conversation_list
    ]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    if conv_type == "chatml":
        conv = conversation_lib.conv_templates['chatml'].copy()
    else:
        conv = conversation_lib.default_conversation.copy()
    targets = input_ids.clone()

    # The previous `conv_type == "llava_v1" or "chatml"` was always truthy, so
    # the "[/INST] " separator could never be selected.
    if conv_type in ("llava_v1", "chatml"):
        sep = conv.sep + conv.roles[1] + ": "
    else:
        sep = "[/INST] "

    # `target` is a row view of `targets`, so the in-place writes below update
    # the returned labels.
    for conversation, target in zip(conversation_list, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        if conv.sep2 not in conversation:
            # NOTE(review): this stops label masking for every remaining
            # conversation in the batch, not just this one -- confirm that
            # `break` (rather than `continue`) is intended.
            break

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            if conv_type == "chatml":
                if DEFAULT_IMAGE_TOKEN in conversation:
                    round_len = len(tokenizer_image_token(rou, tokenizer))
                    instruction_len = len(tokenizer_image_token(rou + sep, tokenizer)) - 2
                else:
                    round_len = len(tokenizer(rou).input_ids)
                    instruction_len = len(tokenizer(rou + sep).input_ids) - 2
                # Only the first round's instruction span is masked here.
                if i == 0:
                    target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            else:
                parts = rou.split(sep)
                assert len(parts) == 2, (len(parts), rou)
                parts[0] += sep
                if DEFAULT_IMAGE_TOKEN in conversation:
                    round_len = len(tokenizer_image_token(rou, tokenizer))
                    instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
                else:
                    round_len = len(tokenizer(rou).input_ids)
                    instruction_len = len(tokenizer(parts[0]).input_ids) - 2
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len

        if conv_type == "chatml":
            # For chatml the accumulated length is replaced by the real
            # non-padding length before the tail is masked.
            cur_len = total_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            assert cur_len == total_len

    if not inferences[0]:
        # Training batches are cut to leave 255 positions below the
        # tokenizer's maximum length.
        truncate_len = tokenizer.model_max_length - 255
        if input_ids.shape[1] > truncate_len:
            input_ids = input_ids[:, :truncate_len]
            targets = targets[:, :truncate_len]
            attention_masks = attention_masks[:, :truncate_len]

    return {
        "image_paths": image_path_list,
        "images": torch.stack(images_list, dim=0),
        "images_clip": torch.stack(images_clip_list, dim=0),
        "input_ids": input_ids,
        "labels": targets,
        "attention_masks": attention_masks,
        "masks_list": masks_list,
        "label_list": label_list,
        "resize_list": resize_list,
        "clip_resize_list": clip_resize_list,
        "offset": torch.LongTensor(offset_list),
        "questions_list": questions_list,
        "sampled_classes_list": sampled_classes_list,
        "inference": inferences[0],
        "conversation_list": conversation_list,
        "multi_reason_list": multi_reasons,
        "categories": categories,
        "answers_list": answers_list,
    }
class HybridDataset(torch.utils.data.Dataset):
    """Training dataset that mixes several task datasets by weighted sampling.

    ``dataset`` is a ``"||"``-separated list of task names.  One sub-dataset
    is built per name, and ``__getitem__`` picks one of them at random with
    probabilities proportional to ``sample_rate``.
    """

    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_image_dir,
        tokenizer,
        vision_tower,
        samples_per_epoch=500 * 8 * 2 * 10,
        precision: str = "fp32",
        image_size: int = 224,
        num_classes_per_sample: int = 3,
        exclude_val=False,
        dataset="sem_seg||refer_seg||vqa||reason_seg",
        sample_rate=(9, 3, 3, 1),
        sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
        refer_seg_data="refclef||refcoco||refcoco+||refcocog",
        vqa_data="llava_instruct_150k",
        reason_seg_data="ReasonSeg|train",
        explanatory=0.1,
        seg_token_num=1,
        num_classes_per_question=1,
        pad_train_clip_images=False,
        masks_process_with_clip=False,
        preprocessor_config='',
        use_expand_question_list=False,
    ):
        """Build one sub-dataset per name listed in ``dataset``.

        Args:
            dataset: ``"||"``-separated task names.  Supported names are
                ``sem_seg``, ``refer_seg``, ``vqa``, ``reason_seg``,
                ``multi_reason_seg`` and ``custom_seg``.
            sample_rate: One relative sampling weight per task name; the
                weights are normalised to probabilities.
            samples_per_epoch: Value reported by ``__len__``.

        The remaining arguments are forwarded to the sub-dataset constructors.

        Raises:
            ValueError: If ``sample_rate`` does not have one entry per task
                name, or if a task name is not supported.
        """
        self.pad_train_clip_images = pad_train_clip_images
        self.exclude_val = exclude_val
        self.dataset = dataset
        self.samples_per_epoch = samples_per_epoch
        self.explanatory = explanatory
        self.num_classes_per_sample = num_classes_per_sample
        sample_rate = np.array(sample_rate)
        self.sample_rate = sample_rate / sample_rate.sum()
        self.seg_token_num = seg_token_num

        self.base_image_dir = base_image_dir
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.precision = precision

        self.datasets = dataset.split("||")
        # Fail here rather than inside np.random.choice on the first sample.
        if len(self.sample_rate) != len(self.datasets):
            raise ValueError(
                "sample_rate has {} entries but {} datasets were requested: {}".format(
                    len(self.sample_rate), len(self.datasets), self.datasets
                )
            )

        self.all_datasets = []
        for name in self.datasets:
            if name == "sem_seg":
                sub_dataset = SemSegDataset(
                    base_image_dir,
                    tokenizer,
                    vision_tower,
                    samples_per_epoch,
                    precision,
                    image_size,
                    num_classes_per_sample,
                    exclude_val,
                    sem_seg_data,
                    num_classes_per_question,
                    seg_token_num,
                    pad_train_clip_images,
                    masks_process_with_clip,
                    preprocessor_config,
                    use_expand_question_list,
                )
            elif name == "refer_seg":
                sub_dataset = ReferSegDataset(
                    base_image_dir,
                    tokenizer,
                    vision_tower,
                    samples_per_epoch,
                    precision,
                    image_size,
                    num_classes_per_sample,
                    exclude_val,
                    refer_seg_data,
                    num_classes_per_question,
                    seg_token_num,
                    pad_train_clip_images,
                    masks_process_with_clip,
                    preprocessor_config,
                    use_expand_question_list,
                )
            elif name == "vqa":
                sub_dataset = VQADataset(
                    base_image_dir,
                    tokenizer,
                    vision_tower,
                    samples_per_epoch,
                    precision,
                    image_size,
                    num_classes_per_sample,
                    exclude_val,
                    vqa_data,
                    pad_train_clip_images,
                    masks_process_with_clip,
                    preprocessor_config,
                )
            elif name == "reason_seg":
                sub_dataset = ReasonSegDataset(
                    base_image_dir,
                    tokenizer,
                    vision_tower,
                    samples_per_epoch,
                    precision,
                    image_size,
                    num_classes_per_sample,
                    exclude_val,
                    reason_seg_data,
                    explanatory,
                    num_classes_per_question,
                    seg_token_num,
                    pad_train_clip_images,
                    masks_process_with_clip,
                    preprocessor_config,
                    use_expand_question_list,
                )
            elif name == "multi_reason_seg":
                sub_dataset = MultiReasonSegDataset(
                    base_image_dir,
                    tokenizer,
                    vision_tower,
                    samples_per_epoch,
                    precision,
                    image_size,
                    num_classes_per_sample,
                    exclude_val,
                    reason_seg_data,
                    explanatory,
                    num_classes_per_question,
                    seg_token_num,
                    pad_train_clip_images,
                    masks_process_with_clip,
                    preprocessor_config,
                    use_expand_question_list,
                )
            elif name == "custom_seg":
                sub_dataset = CustomSegDataset(
                    os.path.join(base_image_dir, "CODrone/DRtrain"),
                    tokenizer,
                    vision_tower,
                    os.path.join(base_image_dir, "labels/DRSeg_train.json"),
                    samples_per_epoch=samples_per_epoch,
                    precision=precision,
                    image_size=image_size,
                    num_classes_per_sample=num_classes_per_sample,
                    exclude_val=exclude_val,
                    seg_token_num=seg_token_num,
                    pad_train_clip_images=pad_train_clip_images,
                    masks_process_with_clip=masks_process_with_clip,
                    preprocessor_config=preprocessor_config,
                )
            else:
                # Skipping an unknown name would leave all_datasets shorter
                # than (and misaligned with) datasets / sample_rate.
                raise ValueError("Unsupported dataset name: {!r}".format(name))
            self.all_datasets.append(sub_dataset)

    def __len__(self):
        """Return the configured number of samples per epoch."""
        return self.samples_per_epoch

    def __getitem__(self, idx):
        """Draw one sample from a randomly chosen sub-dataset.

        ``idx`` is ignored: a sub-dataset is picked according to
        ``sample_rate`` and always indexed at 0 (presumably the sub-datasets
        sample randomly themselves -- TODO confirm).  The result is the
        sub-dataset's tuple extended with the ``multi_reason`` flag (``False``
        for datasets that do not provide it) and ``inference=False``.
        """
        ind = np.random.choice(len(self.datasets), p=self.sample_rate)
        data = self.all_datasets[ind]
        inference = False
        output = data[0]
        if isinstance(data, (MultiReasonSegDataset, CustomSegDataset)):
            # These datasets' tuples are not padded with a multi_reason flag here.
            return (*output, inference)
        return (*output, False, inference)
class ValDataset(torch.utils.data.Dataset):
    """Validation dataset for custom, reasoning and referring segmentation.

    The ``val_dataset`` string selects the data source:

    * ``"custom_seg|<split>"`` wraps a ``CustomSegDataset`` in inference mode.
    * ``"<ds>|<split>"`` globs ``reason_seg/<ds>/<split>/*.jpg`` under
      ``base_image_dir`` (reasoning segmentation).
    * ``"<ds>|<splitBy>|<split>"`` loads referring-segmentation data through
      ``REFER``; a ``multi`` prefix in ``<ds>`` enables multi-target prompts.
    """

    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_image_dir,
        tokenizer,
        vision_tower,
        val_dataset,
        image_size=1024,
        seg_token_num=1,
        pad_val_clip_images=False,
        masks_process_with_clip=False,
        preprocessor_config='',
    ):
        """Index the validation data described by ``val_dataset``.

        Args:
            base_image_dir: Root directory of the datasets.
            tokenizer: Stored on the instance as ``self.tokenizer``.
            vision_tower: CLIP model id/path used for the image processor
                when ``preprocessor_config`` is empty.
            val_dataset: ``"|"``-separated spec with two or three fields (see
                the class docstring).
            image_size: Longest-side target for the segmentation image resize.
            seg_token_num: Number of segmentation tokens in the answer.
            pad_val_clip_images: If true, CLIP inputs are resized and padded
                by this class instead of using the CLIP processor.
            masks_process_with_clip: If true, masks are resized/cropped to the
                CLIP input size.
            preprocessor_config: Optional path used instead of
                ``vision_tower`` to load the CLIP image processor.

        Raises:
            ValueError: If ``val_dataset`` does not have two or three fields.
        """
        self.seg_token_num = seg_token_num
        self.base_image_dir = base_image_dir
        self.pad_val_clip_images = pad_val_clip_images
        self.masks_process_with_clip = masks_process_with_clip
        self.multiseg_inference = False

        splits = val_dataset.split("|")
        if len(splits) not in (2, 3):
            # Otherwise `ds` and `self.data_type` would never be set and the
            # failure would surface as an unrelated UnboundLocalError.
            raise ValueError(
                "val_dataset must look like 'ds|split' or 'ds|splitBy|split', "
                "got {!r}".format(val_dataset)
            )

        if len(splits) == 2:
            ds, split = splits
            if ds == "custom_seg":
                self.data_type = "custom_seg"
                if split == "val":
                    json_file_path = os.path.join(base_image_dir, "labels/DRSeg_val.json")
                elif split == "test":
                    json_file_path = os.path.join(base_image_dir, "labels/DRSeg_test.json")
                else:
                    json_file_path = os.path.join(base_image_dir, "labels/DRSeg_train.json")

                # Every split other than "test" reads images from DRval.
                if split == "test":
                    eval_image_dir = os.path.join(base_image_dir, "CODrone/DRtest")
                else:
                    eval_image_dir = os.path.join(base_image_dir, "CODrone/DRval")

                temp_dataset = CustomSegDataset(
                    base_image_dir=eval_image_dir,
                    tokenizer=tokenizer,
                    vision_tower=vision_tower,
                    json_file_path=json_file_path,
                    samples_per_epoch=0,
                    precision="bf16",
                    image_size=image_size,
                    num_classes_per_sample=1,
                    seg_token_num=seg_token_num,
                    pad_train_clip_images=pad_val_clip_images,
                    masks_process_with_clip=masks_process_with_clip,
                    preprocessor_config=preprocessor_config,
                    inference=True
                )
                self.custom_dataset = temp_dataset
                self.images = list(range(len(temp_dataset)))
            else:
                images = glob.glob(
                    os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
                )
                self.images = images
                self.data_type = "reason_seg"
        else:
            ds, splitBy, split = splits
            if 'multi' in ds:
                self.multiseg_inference = True
                ds = ds.split('multi')[-1]
            if ds == 'rs_reason' or ds == 'rrsisd':
                refer_api = REFER(self.base_image_dir, ds, splitBy)
            else:
                refer_api = REFER(self.base_image_dir+'/refer_seg/', ds, splitBy)
            ref_ids_val = refer_api.getRefIds(split=split)
            images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
            refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)

            refer_seg_ds = {}
            refer_seg_ds["images"] = []
            loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
            for item in loaded_images:
                # Copy so the path rewrite does not touch the REFER API's record.
                item = item.copy()
                if ds == "refclef":
                    item["file_name"] = os.path.join(
                        base_image_dir, "refer_seg/images/saiapr_tc-12", item["file_name"]
                    )
                elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
                    item["file_name"] = os.path.join(
                        base_image_dir,
                        "refer_seg/images/mscoco/images/train2014",
                        item["file_name"],
                    )
                elif ds == 'rrsisd':
                    item["file_name"] = os.path.join(
                        base_image_dir,
                        "rrsisd/images/rrsisd/JPEGImages",
                        item["file_name"],
                    )
                refer_seg_ds["images"].append(item)
            refer_seg_ds["annotations"] = refer_api.Anns

            # Group the referring expressions by the image they belong to.
            img2refs = {}
            for ref in refs_val:
                img2refs.setdefault(ref["image_id"], []).append(ref)
            refer_seg_ds["img2refs"] = img2refs
            self.refer_seg_ds = refer_seg_ds
            self.data_type = "refer_seg"

        self.ds = ds
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.transform = ResizeLongestSide(image_size)
        if preprocessor_config == '':
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        else:
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(preprocessor_config)
        self.transform_clip = ResizeLongestSide(self.clip_image_processor.size['shortest_edge'])

    def __len__(self):
        """Return the number of validation items for the active data type."""
        if self.data_type == "refer_seg":
            return len(self.refer_seg_ds["images"])
        elif self.data_type == "custom_seg":
            return len(self.custom_dataset)
        else:
            return len(self.images)

    def preprocess(self, x: torch.Tensor, decoder_image_size) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        x = (x - self.pixel_mean) / self.pixel_std
        h, w = x.shape[-2:]
        padh = decoder_image_size - h
        padw = decoder_image_size - w
        # Padding is added on the right and bottom only.
        x = F.pad(x, (0, padw, 0, padh))
        return x

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` with OpenCV and return it in RGB channel order."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: {}".format(image_path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the validation sample at ``idx``.

        For ``custom_seg`` the wrapped dataset's item is returned unchanged.
        Otherwise the result is the 12-tuple ``(image_path, image, image_clip,
        conversations, masks, labels, resize, clip_resize, None,
        sampled_sents, False, True)``; the last two entries are the
        ``multi_reason`` and ``inference`` flags.

        Raises:
            ValueError: If a referring-segmentation image has no refs.
            FileNotFoundError: If the image file cannot be read.
        """
        if self.data_type == "custom_seg":
            return self.custom_dataset[idx]

        if self.data_type == "refer_seg":
            refer_seg_ds = self.refer_seg_ds
            annotations = refer_seg_ds["annotations"]
            image_info = refer_seg_ds["images"][idx]
            image_path = image_info["file_name"]
            image_id = image_info["id"]

            refs = refer_seg_ds["img2refs"][image_id]
            if len(refs) == 0:
                raise ValueError("image {} has no refs".format(image_id))

            # One (sentence, annotation id) pair per referring sentence.
            sents = []
            ann_ids = []
            for ref in refs:
                for sent in ref["sentences"]:
                    sents.append(sent["sent"].strip().lower())
                    ann_ids.append(ref["ann_id"])

            sampled_sents = sents
            sampled_ann_ids = ann_ids
            image = self._read_rgb(image_path)
            is_sentence = False
        else:
            image_path = self.images[idx]
            image = self._read_rgb(image_path)
            json_path = image_path.replace(".jpg", ".json")
            mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, image)
            # Only the first sentence is evaluated.
            sampled_sents = [sampled_sents[0]]

        conversations = []
        conv = conversation_lib.default_conversation.copy()
        if self.seg_token_num == 1:
            _seg = "[SEG]"
        else:
            _seg = ' '.join("[SEG{}]".format(k) for k in range(self.seg_token_num))

        # In multi-target mode the sentences are grouped into prompts of
        # 6, 5, 4, 6, ... targets.
        multi_sample_num = [6, 5, 4]
        multi_sample_index = 0
        i = 0
        while i < len(sampled_sents):
            conv.messages = []
            if self.multiseg_inference:
                sample_num = multi_sample_num[multi_sample_index]
                texts = [s.strip() for s in sampled_sents[i : i + sample_num]]
                if len(texts) > 1:
                    text = ', '.join(texts[:-1]) + ' and {}'.format(texts[-1])
                else:
                    text = texts[0]
            else:
                text = sampled_sents[i].strip()

            if is_sentence:
                conv.append_message(
                    conv.roles[0],
                    DEFAULT_IMAGE_TOKEN
                    + "\n {} Please output segmentation mask.".format(text),
                )
                conv.append_message(conv.roles[1], "{}.".format(_seg))
            else:
                conv.append_message(
                    conv.roles[0],
                    DEFAULT_IMAGE_TOKEN
                    + "\n What is {} in this image? Please output segmentation mask.".format(
                        text
                    ),
                )
                if self.multiseg_inference:
                    answer = [_seg] * len(texts)
                    if len(answer) > 1:
                        answer = ', '.join(answer[:-1]) + ' and ' + answer[-1] + '.'
                    else:
                        # NOTE(review): unlike the other answers this one has
                        # no trailing period -- confirm that is intended.
                        answer = answer[0]
                    conv.append_message(conv.roles[1], answer)
                else:
                    conv.append_message(conv.roles[1], "{}.".format(_seg))
            conversations.append(conv.get_prompt())

            if self.multiseg_inference:
                i += sample_num
                multi_sample_index = (multi_sample_index + 1) % len(multi_sample_num)
            else:
                i += 1

        if self.pad_val_clip_images:
            image_clip = self.transform_clip.apply_image(image)
            clip_resize = image_clip.shape[:2]
            image_clip = self.preprocess(
                torch.from_numpy(image_clip).permute(2, 0, 1).contiguous(),
                self.clip_image_processor.size['shortest_edge'],
            )
        else:
            image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ][0]
            clip_resize = image_clip.shape[-2:]

        image = self.transform.apply_image(image)
        resize = image.shape[:2]
        image = self.preprocess(
            torch.from_numpy(image).permute(2, 0, 1).contiguous(), self.img_size
        )

        if self.data_type == "refer_seg":
            masks = []
            for sent_idx, ann_id in enumerate(sampled_ann_ids):
                ann = annotations[ann_id]
                if len(ann["segmentation"]) == 0 and sampled_sents[sent_idx] != "":
                    m = np.zeros((image_info["height"], image_info["width"], 1))
                else:
                    if isinstance(ann["segmentation"][0], list):
                        # A list of lists is converted to RLE by pycocotools.
                        rle = mask.frPyObjects(
                            ann["segmentation"],
                            image_info["height"],
                            image_info["width"],
                        )
                    else:
                        rle = ann["segmentation"]
                        # str counts are encoded to bytes in place before
                        # decoding.
                        for segment in rle:
                            if not isinstance(segment["counts"], bytes):
                                segment["counts"] = segment["counts"].encode()
                    m = mask.decode(rle)
                # Collapse the per-segment channel axis into one mask.
                m = np.sum(m, axis=2)
                m = m.astype(np.uint8)
                masks.append(m)
        else:
            masks = [mask_json]

        masks = np.stack(masks, axis=0)
        masks = torch.from_numpy(masks)
        labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
        inference = True

        if self.masks_process_with_clip:
            mask_shape = image_clip.shape[-1]
            if len(masks) == 0:
                masks = torch.zeros(0, mask_shape, mask_shape)
            else:
                masks = transform_mask(masks, mask_shape)

        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            labels,
            resize,
            clip_resize,
            None,
            sampled_sents,
            False,
            inference,
        )
def transform_mask(masks, size):
    """Resize masks so the shorter side equals ``size``, then center-crop.

    ``masks`` is indexed as ``(..., H, W)`` with one leading dimension; it is
    resized with nearest-neighbour interpolation, converted to bool, and
    cropped to the central ``size`` x ``size`` window.
    """
    src_h, src_w = masks.shape[-2:]

    # Scale the longer side by the same factor that maps the shorter to `size`.
    if src_w <= src_h:
        resized_hw = (int(size * src_h / src_w), size)
    else:
        resized_hw = (size, int(size * src_w / src_h))

    batched = masks[None].float()
    resized = F.interpolate(batched, size=resized_hw, mode="nearest")[0].bool()

    full_h, full_w = resized_hw
    side = int(size)
    y0 = (full_h - side) // 2
    x0 = (full_w - side) // 2
    y1 = y0 + side
    x1 = x0 + side
    assert y0 >= 0 and y1 <= full_h and x0 >= 0 and x1 <= full_w
    return resized[..., y0:y1, x0:x1]
def center_crop_image(image, size):
    """Return the central ``size`` x ``size`` window of an ``(H, W, ...)`` image.

    An ``AssertionError`` is raised when the window does not fit inside the
    image.
    """
    full_h, full_w = image.shape[:2]
    side = int(size)

    y0 = (full_h - side) // 2
    x0 = (full_w - side) // 2
    y1, x1 = y0 + side, x0 + side

    assert y0 >= 0 and y1 <= full_h and x0 >= 0 and x1 <= full_w
    return image[y0:y1, x0:x1]