# Page header residue (not code): lyhisme's picture / Upload 151 files / c02d17f verified
import logging
import os
import random
from dataclasses import dataclass
from multiprocessing import Value
import numpy as np
from training.utils import mask2box
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from open_clip.transform import get_scale
from pycocotools.coco import COCO
from training.coco_api import COCOPanoptic
from panopticapi import utils
import io
# from mmengine.fileio import get
try:
    # petrel_client is optional; it is only needed when images are read from ceph.
    from petrel_client.client import Client
except ImportError:
    # Sentinel checked by the dataset classes before they try to create a client.
    Client = None
from open_clip.transform import ResizeLongest
# import image transforms
from torchvision.transforms import RandomHorizontalFlip, Compose
from training.custom_transforms import CustomRandomResize, CustomRandomCrop
class ProposalDistillDataset(Dataset):
    """COCO-style dataset pairing each image with crops around its annotated boxes.

    Each item is a tuple ``(new_image, boxes_template, image_crops)``:

    * ``new_image``: result of ``transforms[0]`` applied to the full image.
    * ``boxes_template``: ``(max_anns, 5)`` tensor with ``x0, y0, x1, y1`` plus a validity
      flag (1.0 for filled slots). Coordinates are multiplied by ``get_scale(...)`` and then
      divided by the width/height of ``new_image``.
    * ``image_crops``: ``(max_anns, 3, *crop_size)`` tensor with ``transforms[1]`` applied to
      an enlarged crop around every kept box.

    Slots that are not filled stay zero.

    Args:
        input_filename: annotation file handed to ``COCO``.
        transforms: indexable pair; ``[0]`` is applied to the full image, ``[1]`` to crops.
        image_root: directory joined with the image name when ceph is not used.
        crop_size: int or (h, w) pair used to size the crop tensor.
        tokenizer: stored as ``self.tokenize``; not used by this class.
        args: namespace providing ``min_size``, ``max_size`` and ``train_ceph_root``.
    """

    # Maximum number of samples tried per item before giving up on image loading.
    max_load_retries = 100

    def __init__(self, input_filename, transforms, image_root,
                 crop_size=224,
                 tokenizer=None, args=None):
        logging.debug(f'Loading coco style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.image_ids = list(self.coco.imgs.keys())
        self.max_anns = 20
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self.args = args
        self.min_size = args.min_size
        self.max_size = args.max_size
        self.ceph_root = args.train_ceph_root
        # bool() treats both "" and None as "ceph disabled".
        self.use_ceph = bool(self.ceph_root)
        self.FILE_CLIENT = None

    @staticmethod
    def _expand_box(x, y, w, h, img_w, img_h, ratio=0.75):
        """Return the crop window centred on box (x, y, w, h), clipped to the image.

        The window extends ``ratio * w`` / ``ratio * h`` from the box centre on each side.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * ratio, 0), max(cy - h * ratio, 0),
                min(cx + w * ratio, img_w), min(cy + h * ratio, img_h))

    def read_image(self, image_name):
        """Open ``image_name`` from ceph or local disk.

        Returns:
            The opened PIL image, or ``None`` when it cannot be opened or when either side
            is shorter than 10 pixels.

        Raises:
            RuntimeError: ceph reading is requested but ``petrel_client`` is unavailable.
        """
        if self.use_ceph:
            image_path = os.path.join(self.ceph_root, image_name)
            if self.FILE_CLIENT is None:
                if Client is None:
                    raise RuntimeError(
                        "petrel_client is not installed but a ceph root was configured.")
                # Created lazily, on first read.
                self.FILE_CLIENT = Client()
            try:
                img_bytes = self.FILE_CLIENT.get(image_path)
                image = Image.open(io.BytesIO(img_bytes))
            except Exception:
                # Best effort: the caller falls back to another sample. Unlike a bare
                # except, this lets KeyboardInterrupt/SystemExit propagate.
                print(f"Cannot load {image_path}", flush=True)
                return None
        else:
            image_path = os.path.join(self.image_root, image_name)
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Cannot load {image_path}", flush=True)
                return None
        width, height = image.size
        if width < 10 or height < 10:
            print(f"Invalid image, size {image.size}", flush=True)
            return None
        return image

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(new_image, boxes_template, image_crops)`` for sample ``idx``.

        If the image cannot be loaded, random other samples are tried instead, at most
        ``max_load_retries`` attempts in total.

        Raises:
            RuntimeError: no valid image could be loaded within the retry budget.
        """
        old_image = None
        for _ in range(self.max_load_retries):
            image_id = self.image_ids[idx]
            image_info = self.coco.imgs[image_id]
            if 'file_name' in image_info:
                image_name = image_info['file_name']
            else:
                assert 'coco_url' in image_info
                # Use the last two URL components as "<folder>/<file>".
                coco_url = image_info['coco_url'].split('/')
                image_name = os.path.join(coco_url[-2], coco_url[-1])
            old_image = self.read_image(image_name)
            if old_image is not None:
                break
            # Loop instead of recursing so repeated failures cannot exhaust the stack.
            idx = random.randrange(len(self))
        if old_image is None:
            raise RuntimeError(
                f"Failed to load a valid image after {self.max_load_retries} attempts.")

        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 1)  # xyxy s
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        # Randomly pick at most max_anns annotations.
        indices = list(range(len(anns)))
        random.shuffle(indices)
        num_valid_boxes = 0
        for i, ann_id in enumerate(indices[:self.max_anns]):
            x, y, w, h = anns[ann_id]['bbox']
            area = w * h
            # Skip boxes outside the configured area range; their slot stays zero.
            if area < (self.min_size ** 2) or area > (self.max_size ** 2):
                continue
            num_valid_boxes += 1
            crop_window = self._expand_box(x, y, w, h, img_w, img_h)
            image_crops[i] = self.transforms[1](old_image.crop(crop_window))
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, 1.0])

        if num_valid_boxes == 0:
            # Avoid an empty sample: use the top-left quarter of the image as a box.
            boxes_template[0] = torch.tensor([0, 0, img_w / 4, img_h / 4, 1.0])
            image_crops[0] = self.transforms[1](old_image.crop((0, 0, img_w // 4, img_h // 4)))

        _, h, w = new_image.shape
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h
        return new_image, boxes_template, image_crops
class GridDistillDataset(Dataset):
    """COCO-style dataset that splits each image into a random M x N grid of crops.

    Each item is a tuple ``(new_image, boxes_template, image_crops_template)``:

    * ``new_image``: result of ``transforms[0]`` applied to the full image.
    * ``boxes_template``: ``(max_anns, 5)`` tensor with ``x0, y0, x1, y1`` plus a validity
      flag (1.0 for filled slots). Coordinates are multiplied by ``get_scale(...)`` and then
      divided by the width/height of ``new_image``.
    * ``image_crops_template``: ``(max_anns, 3, *crop_size)`` tensor with ``transforms[1]``
      applied to each selected grid cell.

    Args:
        input_filename: annotation file handed to ``COCO``.
        transforms: indexable pair; ``[0]`` is applied to the full image, ``[1]`` to crops.
        image_root: directory joined with the image name when ceph is not used.
        max_split: largest number of rows/columns considered for a grid.
        crop_size: int or (h, w) pair used to size the crop tensor.
        pre_transforms: when truthy, a resize/crop/flip pipeline is built and stored.
        ceph_root: ceph prefix; an empty value disables ceph reading.
        args: namespace providing ``train_ratio``, ``max_boxes`` and ``crop_scale``.
    """

    # Maximum number of samples tried per item before giving up on image loading.
    max_load_retries = 100

    def __init__(self,
                 input_filename, transforms, image_root,
                 max_split=16,
                 crop_size=224,
                 pre_transforms=False,
                 ceph_root="", args=None):
        self._init_choices(max_split)
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.image_root = image_root
        self.args = args
        image_ids = list(self.coco.imgs.keys())
        train_ratio = args.train_ratio
        if train_ratio < 1.0:
            # Keep a random subset of the images.
            num_images = int(len(image_ids) * train_ratio)
            random.shuffle(image_ids)
            image_ids = image_ids[:num_images]
        self.image_ids = image_ids
        self.max_anns = args.max_boxes
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self._init_boxes()
        self.ceph_root = ceph_root
        # bool() treats both "" and None as "ceph disabled".
        self.use_ceph = bool(ceph_root)
        self.FILE_CLIENT = None
        # NOTE(review): pre_transforms is built here but never applied in __getitem__;
        # confirm whether that is intended.
        if pre_transforms:
            self.pre_transforms = Compose([
                CustomRandomResize(scale=(0.5, 2.0)),
                CustomRandomCrop(size=self.transforms[0].transforms[0].max_size),
                RandomHorizontalFlip()])
        else:
            self.pre_transforms = None

    def read_image(self, image_name):
        """Open ``image_name`` from ceph or local disk.

        Returns:
            The opened PIL image, or ``None`` when it cannot be opened or when either side
            is shorter than 10 pixels.

        Raises:
            RuntimeError: ceph reading is requested but ``petrel_client`` is unavailable.
        """
        if self.use_ceph:
            image_path = os.path.join(self.ceph_root, image_name)
            if self.FILE_CLIENT is None:
                if Client is None:
                    raise RuntimeError(
                        "petrel_client is not installed but a ceph root was configured.")
                # Created lazily, on first read.
                self.FILE_CLIENT = Client()
            try:
                img_bytes = self.FILE_CLIENT.get(image_path)
                image = Image.open(io.BytesIO(img_bytes))
            except Exception:
                # Best effort: the caller falls back to another sample. Unlike a bare
                # except, this lets KeyboardInterrupt/SystemExit propagate.
                print(f"Cannot load {image_path}", flush=True)
                return None
        else:
            image_path = os.path.join(self.image_root, image_name)
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Cannot load {image_path}", flush=True)
                return None
        width, height = image.size
        if width < 10 or height < 10:
            print(f"Invalid image, size {image.size}", flush=True)
            return None
        return image

    def _init_choices(self, M=16):
        """Enumerate the allowed (m, n) grid shapes.

        ``m`` runs over 2..M and ``n`` over ``(m + 1)//2 + 1 .. min(2*m, M)``.
        """
        self.choices = [
            (m, n)
            for m in range(2, M + 1)
            for n in range((m + 1) // 2 + 1, min(m * 2 + 1, M + 1))
        ]

    def __len__(self):
        return len(self.image_ids)

    def _init_boxes(self, ):
        """Precompute the normalised xyxy cell boxes for every grid choice."""
        box_templates = {}
        for choice in self.choices:
            M, N = choice
            # grid_x / grid_y have shape (M + 1, N + 1): cell corners in [0, 1].
            grid_x, grid_y = torch.meshgrid(torch.linspace(0, 1, N + 1), torch.linspace(0, 1, M + 1),
                                            indexing='xy')
            x0y0s = torch.stack([grid_x[:M, :N], grid_y[:M, :N]], dim=-1)
            x1y1s = torch.stack([grid_x[1:, 1:], grid_y[1:, 1:]], dim=-1)
            pseudo_boxes = torch.cat([x0y0s, x1y1s],
                                     dim=-1).view(-1, 4)
            assert pseudo_boxes.shape[0] == M * N
            box_templates[choice] = pseudo_boxes
        self.box_templates = box_templates

    def _obtain_image_crops(self, image, choice):
        """Crop up to ``max_anns`` randomly chosen grid cells from ``image``.

        Returns:
            ``(crops, boxes)``: the stacked ``transforms[1]`` outputs and the matching
            un-enlarged cell boxes in pixel coordinates of ``image``.
        """
        image_crops = []
        img_w, img_h = image.size
        normed_boxes = self.box_templates[choice]
        indices = list(range(len(normed_boxes)))
        random.shuffle(indices)
        indices = indices[:self.max_anns]
        boxes = normed_boxes * torch.tensor([img_w, img_h, img_w, img_h])
        for idx in indices:
            x0, y0, x1, y1 = boxes[idx].tolist()
            if self.args.crop_scale > 1.0:
                # Enlarge the crop window around the cell centre, clipped to the image.
                box_w, box_h = x1 - x0, y1 - y0
                cx, cy = (x1 + x0) / 2, (y1 + y0) / 2
                delta_factor = 0.5 * self.args.crop_scale
                x0, y0, x1, y1 = max(cx - box_w * delta_factor, 0), max(cy - box_h * delta_factor, 0), \
                    min(cx + box_w * delta_factor, img_w), min(cy + box_h * delta_factor, img_h)
            image_crops.append(self.transforms[1](image.crop((x0, y0, x1, y1))))
        return torch.stack(image_crops), boxes[indices]

    def __getitem__(self, idx):
        """Return ``(new_image, boxes_template, image_crops_template)`` for sample ``idx``.

        If the image cannot be loaded, random other samples are tried instead, at most
        ``max_load_retries`` attempts in total.

        Raises:
            RuntimeError: no valid image could be loaded within the retry budget.
        """
        old_image = None
        for _ in range(self.max_load_retries):
            image_id = self.image_ids[idx]
            image_info = self.coco.imgs[image_id]
            if 'file_name' in image_info:
                image_name = image_info['file_name']
            else:
                assert 'coco_url' in image_info
                # Use the last two URL components as "<folder>/<file>".
                coco_url = image_info['coco_url'].split('/')
                image_name = os.path.join(coco_url[-2], coco_url[-1])
            old_image = self.read_image(image_name)
            if old_image is not None:
                break
            # Loop instead of recursing so repeated failures cannot exhaust the stack.
            idx = random.randrange(len(self))
        if old_image is None:
            raise RuntimeError(
                f"Failed to load a valid image after {self.max_load_retries} attempts.")

        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        boxes_template = torch.zeros(self.max_anns, 4 + 1)  # xyxy s
        image_crops_template = torch.zeros(self.max_anns, 3, *self.crop_size)
        image_crops, boxes = self._obtain_image_crops(old_image,
                                                      random.choice(self.choices))
        assert image_crops.shape[0] == boxes.shape[0]
        _, h, w = new_image.shape
        boxes[:, :4] *= scale
        boxes[:, [0, 2]] /= w
        boxes[:, [1, 3]] /= h
        num_boxes = boxes.shape[0]
        boxes_template[:num_boxes, :4] = boxes
        boxes_template[:num_boxes, 4] = 1.0
        image_crops_template[:num_boxes] = image_crops
        return new_image, boxes_template, image_crops_template
class COCOPanopticDataset(Dataset):
    """Panoptic dataset yielding per-segment crops, masked crops and ground-truth masks.

    Each item is a tuple
    ``(new_image, boxes_template, image_crops, gt_masks, masked_image_crops)``:

    * ``new_image``: result of ``transforms[0]`` applied to the full image.
    * ``boxes_template``: ``(max_anns, 8)`` tensor holding
      ``x0, y0, x1, y1, class label, valid flag, box area, isthing``. Box coordinates are
      multiplied by ``get_scale(...)`` and divided by the width/height of ``new_image``.
    * ``image_crops`` / ``masked_image_crops``: ``(max_anns, 3, *crop_size)`` tensors with
      ``transforms[1]`` applied to the crop of the original image and of a copy in which
      every pixel outside the segment is set to 114.
    * ``gt_masks``: ``(max_anns, S, S)`` tensor where ``S`` is ``segm_transform.max_size``.

    Slots that are not filled stay zero.

    Args:
        input_filename: annotation file handed to ``COCOPanoptic``.
        transforms: indexable pair; ``[0]`` is applied to the full image, ``[1]`` to crops.
        image_root: directory containing the images.
        embed_path: ``.npy`` file loaded into ``self.embeddings``.
        segm_root: directory containing the panoptic segmentation files.
        crop_size: int or (h, w) pair used to size the crop tensors.
        tokenizer: stored as ``self.tokenize``; not used by this class.
        downsample_factor: divisor applied to the image transform's ``max_size`` to get
            the mask resolution.
        min_size, max_size: accepted for interface compatibility; the values actually used
            are fixed to 8 and 1024 below.
    """

    def __init__(self, input_filename, transforms, image_root, embed_path,
                 segm_root,
                 crop_size=224,
                 tokenizer=None,
                 downsample_factor=16,
                 min_size=8, max_size=1024):
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCOPanoptic(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.embeddings = np.load(embed_path)
        self.image_ids = list(self.coco.imgs.keys())
        num_annos = [len(anns) for anns in self.coco.imgToAnns.values()]
        # default=0 keeps construction working when there are no annotations at all.
        self.max_anns = min(max(num_annos, default=0), 100)
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self.min_size = 8  # fix for val
        self.max_size = 1024
        self.segm_root = segm_root
        self.downsample_factor = downsample_factor
        # Downsample masks to the output size of the image encoder.
        self.segm_transform = ResizeLongest(max_size=self.transforms[0].transforms[0].max_size // downsample_factor,
                                            fill=0)
        cat_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        # Map category ids to contiguous labels in sorted-id order.
        self.cat_id2label = {cat_id: label for label, cat_id in enumerate(cat_ids)}

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _expand_box(x, y, w, h, img_w, img_h, ratio=0.75):
        """Return the crop window centred on box (x, y, w, h), clipped to the image.

        The window extends ``ratio * w`` / ``ratio * h`` from the box centre on each side.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * ratio, 0), max(cy - h * ratio, 0),
                min(cx + w * ratio, img_w), min(cy + h * ratio, img_h))

    @staticmethod
    def _load_segm(segm_path):
        """Load a panoptic PNG and convert it to a segment-id map with ``utils.rgb2id``."""
        # The context manager closes the file as soon as the pixels are copied out.
        with Image.open(segm_path) as segm_image:
            segmentation = np.array(segm_image, dtype=np.uint8)
        return utils.rgb2id(segmentation)

    def __getitem__(self, idx):
        """Build the crops, masks and box table for sample ``idx`` (see class docstring)."""
        image_id = self.image_ids[idx]
        image_info = self.coco.imgs[image_id]
        image_name = image_info['file_name']
        segm_file = image_info['segm_file']
        image_path = os.path.join(self.image_root, image_name)
        segm_path = os.path.join(self.segm_root, segm_file)
        segm_map = self._load_segm(segm_path)
        old_image = Image.open(image_path)
        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 2 + 1 + 1)  # xyxy c valid size, isthing
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)
        gt_masks = torch.zeros(self.max_anns, self.segm_transform.max_size,
                               self.segm_transform.max_size)
        masked_image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        # Convert the image once; every annotation works on its own copy.
        image_pixels = np.asarray(old_image)

        for i, ann in enumerate(anns):
            if i == self.max_anns:
                break
            cat_id = ann['category_id']
            is_thing = self.coco.cats[cat_id]['isthing']
            if is_thing > 0:
                # Things: annotated box for the record, enlarged window for the crop.
                x, y, w, h = ann['bbox']
                x0, y0, x1, y1 = self._expand_box(x, y, w, h, img_w, img_h)
            else:
                # Stuff: box derived from the segment mask, used for both record and crop.
                x0, y0, x1, y1 = mask2box(segm_map == ann['id'])
                x, y, w, h = x0, y0, x1 - x0, y1 - y0
            if w * h < (self.min_size ** 2) or w * h > (self.max_size ** 2):
                continue
            crop_window = (x0, y0, x1, y1)
            image_crops[i] = self.transforms[1](old_image.crop(crop_window))

            # Masked crop: pixels outside the segment are replaced by 114.
            instance_mask = segm_map == ann['id']
            masked_pixels = image_pixels.copy()
            masked_pixels[~instance_mask] = 114
            masked_old_image = Image.fromarray(masked_pixels)
            masked_image_crops[i] = self.transforms[1](masked_old_image.crop(crop_window))

            gt_mask = torch.from_numpy(instance_mask).float()
            gt_mask = self.segm_transform(gt_mask[None]) > 0.0
            cls_label = self.cat_id2label[cat_id]
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, cls_label, 1.0, w * h, is_thing])
            gt_masks[i] = gt_mask[0]

        _, h, w = new_image.shape
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h
        return new_image, boxes_template, image_crops, gt_masks, masked_image_crops
class ADEPanopticDataset(Dataset):
    """Panoptic dataset yielding per-segment crops, masked crops and ground-truth masks.

    Each item is a tuple
    ``(new_image, boxes_template, image_crops, gt_masks, masked_image_crops)``:

    * ``new_image``: result of ``transforms[0]`` applied to the full image.
    * ``boxes_template``: ``(max_anns, 8)`` tensor holding
      ``x0, y0, x1, y1, class label, valid flag, box area, isthing``. Box coordinates are
      multiplied by ``get_scale(...)`` and divided by the width/height of ``new_image``.
    * ``image_crops`` / ``masked_image_crops``: ``(max_anns, 3, *crop_size)`` tensors with
      ``transforms[1]`` applied to the crop of the original image and of a copy in which
      every pixel outside the segment is set to 114.
    * ``gt_masks``: ``(max_anns, S, S)`` tensor where ``S`` is ``segm_transform.max_size``.

    Slots that are not filled stay zero.

    Args:
        input_filename: annotation file handed to ``COCOPanoptic``.
        transforms: indexable pair; ``[0]`` is applied to the full image, ``[1]`` to crops.
        image_root: directory containing the images.
        embed_path: ``.npy`` file loaded into ``self.embeddings``.
        segm_root: directory containing the panoptic segmentation files.
        crop_size: int or (h, w) pair used to size the crop tensors.
        tokenizer: stored as ``self.tokenize``; not used by this class.
        downsample_factor: divisor applied to the image transform's ``max_size`` to get
            the mask resolution.
        min_size, max_size: accepted for interface compatibility; the values actually used
            are fixed to 8 and 1024 below.
    """

    def __init__(self, input_filename, transforms, image_root, embed_path,
                 segm_root,
                 crop_size=224,
                 tokenizer=None,
                 downsample_factor=16,
                 min_size=8, max_size=1024):
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCOPanoptic(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.embeddings = np.load(embed_path)
        self.image_ids = list(self.coco.imgs.keys())
        num_annos = [len(anns) for anns in self.coco.imgToAnns.values()]
        # default=0 keeps construction working when there are no annotations at all.
        self.max_anns = min(max(num_annos, default=0), 100)
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self.min_size = 8  # fix for val
        self.max_size = 1024
        self.segm_root = segm_root
        self.downsample_factor = downsample_factor
        # Downsample masks to the output size of the image encoder.
        self.segm_transform = ResizeLongest(max_size=self.transforms[0].transforms[0].max_size // downsample_factor,
                                            fill=0)
        cat_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        # Map category ids to contiguous labels in sorted-id order.
        self.cat_id2label = {cat_id: label for label, cat_id in enumerate(cat_ids)}

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _expand_box(x, y, w, h, img_w, img_h, ratio=0.75):
        """Return the crop window centred on box (x, y, w, h), clipped to the image.

        The window extends ``ratio * w`` / ``ratio * h`` from the box centre on each side.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * ratio, 0), max(cy - h * ratio, 0),
                min(cx + w * ratio, img_w), min(cy + h * ratio, img_h))

    @staticmethod
    def _load_segm(segm_path):
        """Load a panoptic PNG and convert it to a segment-id map with ``utils.rgb2id``."""
        # The context manager closes the file as soon as the pixels are copied out.
        with Image.open(segm_path) as segm_image:
            segmentation = np.array(segm_image, dtype=np.uint8)
        return utils.rgb2id(segmentation)

    def __getitem__(self, idx):
        """Build the crops, masks and box table for sample ``idx`` (see class docstring)."""
        image_id = self.image_ids[idx]
        image_info = self.coco.imgs[image_id]
        image_name = image_info['file_name']
        segm_file = image_info['segm_file']
        image_path = os.path.join(self.image_root, image_name)
        segm_path = os.path.join(self.segm_root, segm_file)
        segm_map = self._load_segm(segm_path)
        old_image = Image.open(image_path)
        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 2 + 1 + 1)  # xyxy c valid size, isthing
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)
        gt_masks = torch.zeros(self.max_anns, self.segm_transform.max_size,
                               self.segm_transform.max_size)
        masked_image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        # Convert the image once; every annotation works on its own copy.
        image_pixels = np.asarray(old_image)

        for i, ann in enumerate(anns):
            if i == self.max_anns:
                break
            cat_id = ann['category_id']
            is_thing = self.coco.cats[cat_id]['isthing']
            if is_thing > 0:
                # Things: annotated box for the record, enlarged window for the crop.
                x, y, w, h = ann['bbox']
                x0, y0, x1, y1 = self._expand_box(x, y, w, h, img_w, img_h)
            else:
                # Stuff: box derived from the segment mask, used for both record and crop.
                x0, y0, x1, y1 = mask2box(segm_map == ann['id'])
                x, y, w, h = x0, y0, x1 - x0, y1 - y0
            if w * h < (self.min_size ** 2) or w * h > (self.max_size ** 2):
                continue
            crop_window = (x0, y0, x1, y1)
            image_crops[i] = self.transforms[1](old_image.crop(crop_window))

            # Masked crop: pixels outside the segment are replaced by 114.
            instance_mask = segm_map == ann['id']
            masked_pixels = image_pixels.copy()
            masked_pixels[~instance_mask] = 114
            masked_old_image = Image.fromarray(masked_pixels)
            masked_image_crops[i] = self.transforms[1](masked_old_image.crop(crop_window))

            gt_mask = torch.from_numpy(instance_mask).float()
            gt_mask = self.segm_transform(gt_mask[None]) > 0.0
            cls_label = self.cat_id2label[cat_id]
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, cls_label, 1.0, w * h, is_thing])
            gt_masks[i] = gt_mask[0]

        _, h, w = new_image.shape
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h
        return new_image, boxes_template, image_crops, gt_masks, masked_image_crops
class COCORegionCLIPDataset(Dataset):
    """COCO-style dataset returning an image together with its labelled boxes.

    Each item is ``(new_image, boxes_template)`` where ``boxes_template`` is a
    ``(max_anns, 6)`` tensor of ``x0, y0, x1, y1, class label, valid flag``. Coordinates are
    multiplied by ``get_scale(...)`` and divided by the width/height of ``new_image``;
    unused slots stay zero.
    """

    def __init__(self, input_filename, transforms, image_root, args):
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.image_root = image_root

        # Only images that have annotations are used.
        annotated_ids = list(self.coco.imgToAnns.keys())
        ratio = args.train_ratio
        if ratio < 1.0:
            keep = int(len(annotated_ids) * ratio)
            random.shuffle(annotated_ids)
            annotated_ids = annotated_ids[:keep]
        self.image_ids = annotated_ids

        # Slots per image: the largest annotation count in the file, capped at 20.
        largest = max([len(anns) for anns in self.coco.imgToAnns.values()])
        self.max_anns = min(largest, 20)

        self.args = args
        self.ceph_root = args.train_ceph_root
        self.use_ceph = (self.ceph_root != "")
        self.FILE_CLIENT = None

        # Contiguous labels assigned in sorted category-id order.
        ordered_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        self.cat_id2label = dict(zip(ordered_ids, range(len(ordered_ids))))

    def __len__(self):
        return len(self.image_ids)

    def read_image(self, image_name):
        """Open ``image_name`` from local disk, or from ceph when ``use_ceph`` is set."""
        if not self.use_ceph:
            return Image.open(os.path.join(self.image_root, image_name))
        remote_path = os.path.join(self.ceph_root, image_name)
        if self.FILE_CLIENT is None:
            # The client is created on first use.
            self.FILE_CLIENT = Client()
        payload = self.FILE_CLIENT.get(remote_path)
        return Image.open(io.BytesIO(payload))

    def __getitem__(self, idx):
        """Return ``(new_image, boxes_template)`` for sample ``idx``."""
        image_id = self.image_ids[idx]
        file_name = self.coco.imgs[image_id]['file_name']
        old_image = self.read_image(file_name)
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)

        boxes_template = torch.zeros(self.max_anns, 4 + 2)  # xyxy cls valid
        for slot, ann in enumerate(self.coco.imgToAnns[image_id]):
            if slot >= self.max_anns:
                break
            category = ann['category_id']
            left, top, box_w, box_h = ann['bbox']
            label = self.cat_id2label[category]
            boxes_template[slot] = torch.tensor(
                [left, top, left + box_w, top + box_h, label, 1.0])

        _, out_h, out_w = new_image.shape
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= out_w
        boxes_template[:, [1, 3]] /= out_h
        return new_image, boxes_template
def get_coco_panoptic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build a ``COCOPanopticDataset`` loader and return it wrapped in ``DataInfo``.

    The annotation file is ``args.train_data`` when ``is_train`` else ``args.val_data``.
    Image and segmentation roots always come from ``args.val_image_root`` and
    ``args.val_segm_root``. ``epoch`` is accepted but not used.
    """
    annotation_file = args.train_data if is_train else args.val_data
    assert annotation_file
    panoptic_set = COCOPanopticDataset(
        annotation_file,
        preprocess_fn,
        segm_root=args.val_segm_root,
        image_root=args.val_image_root,
        embed_path=args.embed_path,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        min_size=args.min_size,
        max_size=args.max_size,
        downsample_factor=args.downsample_factor
    )
    sample_count = len(panoptic_set)

    # TODO: distributed for test
    if args.distributed:
        dist_sampler = DistributedSampler(panoptic_set)
    else:
        dist_sampler = None

    # Only support bs = 1 for inference.
    per_step = args.batch_size if is_train else min(args.batch_size, 1)

    loader = DataLoader(
        panoptic_set,
        batch_size=per_step,
        shuffle=is_train and dist_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=is_train,
    )
    loader.num_samples = sample_count
    loader.num_batches = len(loader)
    return DataInfo(loader, dist_sampler)
def get_ade_panoptic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build an ``ADEPanopticDataset`` loader and return it wrapped in ``DataInfo``.

    The annotation file is ``args.train_data`` when ``is_train`` else ``args.val_data``.
    Image and segmentation roots always come from ``args.val_image_root`` and
    ``args.val_segm_root``. ``epoch`` is accepted but not used.
    """
    annotation_file = args.train_data if is_train else args.val_data
    assert annotation_file
    panoptic_set = ADEPanopticDataset(
        annotation_file,
        preprocess_fn,
        segm_root=args.val_segm_root,
        image_root=args.val_image_root,
        embed_path=args.embed_path,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        min_size=args.min_size,
        max_size=args.max_size,
        downsample_factor=args.downsample_factor
    )
    sample_count = len(panoptic_set)

    # TODO: distributed for test
    if args.distributed:
        dist_sampler = DistributedSampler(panoptic_set)
    else:
        dist_sampler = None

    # Only support bs = 1 for inference.
    per_step = args.batch_size if is_train else min(args.batch_size, 1)

    loader = DataLoader(
        panoptic_set,
        batch_size=per_step,
        shuffle=is_train and dist_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=is_train,
    )
    loader.num_samples = sample_count
    loader.num_batches = len(loader)
    return DataInfo(loader, dist_sampler)
def get_proposal_distill_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training loader for ``ProposalDistillDataset``, wrapped in ``DataInfo``.

    Training only: ``is_train`` is asserted. ``epoch`` is accepted but not used.
    """
    assert is_train
    annotation_file = args.train_data
    assert annotation_file
    proposal_set = ProposalDistillDataset(
        annotation_file,
        preprocess_fn,
        image_root=args.train_image_root,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        args=args
    )
    sample_count = len(proposal_set)

    # TODO: distributed for test
    if args.distributed:
        dist_sampler = DistributedSampler(proposal_set)
    else:
        dist_sampler = None

    loader = DataLoader(
        proposal_set,
        batch_size=args.batch_size,
        shuffle=is_train and dist_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=is_train,
    )
    loader.num_samples = sample_count
    loader.num_batches = len(loader)
    return DataInfo(loader, dist_sampler)
def get_grid_distill_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training loader for ``GridDistillDataset``, wrapped in ``DataInfo``.

    Training only: ``is_train`` is asserted. ``epoch`` and ``tokenizer`` are accepted but
    not used.
    """
    assert is_train
    annotation_file = args.train_data
    assert annotation_file
    grid_set = GridDistillDataset(
        input_filename=annotation_file,
        transforms=preprocess_fn,
        image_root=args.train_image_root,
        crop_size=args.input_size,
        max_split=args.max_split,
        ceph_root=args.train_ceph_root,
        pre_transforms=args.pre_transforms,
        args=args
    )
    sample_count = len(grid_set)

    # TODO: distributed for test
    if args.distributed:
        dist_sampler = DistributedSampler(grid_set)
    else:
        dist_sampler = None

    loader = DataLoader(
        grid_set,
        batch_size=args.batch_size,
        shuffle=is_train and dist_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=is_train,
    )
    loader.num_samples = sample_count
    loader.num_batches = len(loader)
    return DataInfo(loader, dist_sampler)
def get_region_clip_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training loader for ``COCORegionCLIPDataset``, wrapped in ``DataInfo``.

    Training only: ``is_train`` is asserted. ``epoch`` and ``tokenizer`` are accepted but
    not used.
    """
    assert is_train
    annotation_file = args.train_data
    assert annotation_file
    region_set = COCORegionCLIPDataset(
        input_filename=annotation_file,
        transforms=preprocess_fn,
        image_root=args.train_image_root,
        args=args,
    )
    sample_count = len(region_set)

    # TODO: distributed for test
    if args.distributed:
        dist_sampler = DistributedSampler(region_set)
    else:
        dist_sampler = None

    loader = DataLoader(
        region_set,
        batch_size=args.batch_size,
        shuffle=is_train and dist_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=dist_sampler,
        drop_last=is_train,
    )
    loader.num_samples = sample_count
    loader.num_batches = len(loader)
    return DataInfo(loader, dist_sampler)
class SharedEpoch:
    """Epoch counter stored in a ``multiprocessing.Value`` of type ``'i'``."""

    def __init__(self, epoch: int = 0):
        counter = Value('i', epoch)
        self.shared_epoch = counter

    def set_value(self, epoch) -> None:
        """Store ``epoch`` in the shared counter."""
        self.shared_epoch.value = epoch

    def get_value(self):
        """Return the epoch currently held by the shared counter."""
        current = self.shared_epoch.value
        return current
@dataclass
class DataInfo:
    """Bundle of a dataloader with its optional sampler and shared epoch counter."""

    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        """Propagate ``epoch`` to the shared counter and to a distributed sampler, if set."""
        counter = self.shared_epoch
        if counter is not None:
            counter.set_value(epoch)
        # isinstance() is already False for None, so no separate None check is needed.
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
def get_dataset_fn(data_path, dataset_type):
    """Return the loader-building function registered for ``dataset_type``.

    ``data_path`` is accepted but not used.

    Raises:
        ValueError: ``dataset_type`` is not one of the supported names.
    """
    # Guard clauses: each supported name returns its builder immediately.
    if dataset_type == 'coco_panoptic':
        return get_coco_panoptic_dataset
    if dataset_type == 'ade_panoptic':
        return get_ade_panoptic_dataset
    if dataset_type == 'proposals_distill':
        return get_proposal_distill_dataset
    if dataset_type == 'grid_distill':
        return get_grid_distill_dataset
    if dataset_type == 'region_clip':
        return get_region_clip_dataset
    raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    """Build the ``"train"`` and/or ``"val"`` entries requested by ``args``.

    ``preprocess_fns`` is a ``(train, val)`` pair. An entry is created only when the
    corresponding ``args.train_data`` / ``args.val_data`` is truthy; the builder is chosen
    from ``args.dataset_type`` for training and ``args.test_type`` for validation.
    """
    train_preprocess, val_preprocess = preprocess_fns
    loaders = {}

    if args.train_data:
        build_train = get_dataset_fn(args.train_data, args.dataset_type)
        loaders["train"] = build_train(
            args, train_preprocess, is_train=True, epoch=epoch, tokenizer=tokenizer)

    if args.val_data:
        build_val = get_dataset_fn(args.val_data, dataset_type=args.test_type)
        loaders["val"] = build_val(
            args, val_preprocess, is_train=False, tokenizer=tokenizer)

    return loaders