import os
from os import path, replace

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np

from dataset.range_transform import im_normalization, im_mean, im_rgb2lab_normalization, ToTensor, RGB2Lab
from dataset.reseed import reseed
import util.functional as F


class VOSDataset_221128_TransColorization_batch(Dataset):
    """
    Video colorization training dataset over a ``im_root/<video>/<frame>.jpg``
    folder layout (DAVIS/YouTubeVOS/BL30K style).

    For each sequence:
        - Sample ``num_frames`` distinct frames; each new frame lies within
          ``max_jump`` of an already-sampled frame. The order is reversed
          with probability 0.5.
        - Apply random transforms shared by all sampled frames (horizontal
          flip, random resized crop, color jitter).
        - Apply a further random transform to each frame individually
          (affine, light color jitter).
        - Convert each frame with ``RGB2Lab`` -> ``ToTensor`` ->
          ``im_rgb2lab_normalization``. Split it into the first channel (L,
          replicated to 3 channels, used as input) and the next two channels
          (ab, used as the target).

    ``gt_root`` is stored but never read. The colorization target is derived
    from the frames themselves.
    """

    def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=2, finetune=False,
                 patch_size=448):
        """
        Args:
            im_root: directory containing one sub-directory of frames per video.
            gt_root: stored as ``self.gt_root``; not read by this dataset.
            max_jump: maximum index distance between a newly sampled frame and
                some already-sampled frame.
            is_bl: if true, disables the per-frame rotation/shear augmentation.
            subset: optional container of video names to keep (others skipped).
            num_frames: frames per sample; videos with fewer frames are skipped.
            max_num_obj: length of the returned ``selector`` vector.
            finetune: if true, disables the per-frame rotation/shear augmentation.
            patch_size: side length of the square random resized crop (default 448).
        """
        self.im_root = im_root
        self.gt_root = gt_root  # kept for API compatibility; not read by __getitem__
        self.max_jump = max_jump
        self.is_bl = is_bl
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj

        self.videos = []
        self.frames = {}

        vid_list = sorted(os.listdir(self.im_root))
        # Pre-filtering: keep videos in `subset` (if given) that have enough frames
        for vid in vid_list:
            if subset is not None and vid not in subset:
                continue
            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
            if len(frames) < num_frames:
                continue
            self.frames[vid] = frames
            self.videos.append(vid)

        print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))

        # No rotation/shear when fine-tuning or on BL30K
        no_affine = finetune or self.is_bl

        # Per-frame transforms: same for an im/gt pair, different among the sampled frames
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.01, 0.01, 0.01, 0),
        ])

        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0 if no_affine else 15, shear=0 if no_affine else 10,
                                    interpolation=InterpolationMode.BILINEAR, fill=im_mean),
        ])

        # Retained as an attribute for compatibility; __getitem__ does not load masks.
        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0 if no_affine else 15, shear=0 if no_affine else 10,
                                    interpolation=InterpolationMode.NEAREST, fill=0),
        ])

        # Sequence-level transforms: identical for every frame of one sample (driven by a shared seed)
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.03, 0.03, 0),
        ])

        self.all_im_dual_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((patch_size, patch_size), scale=(0.36, 1.00),
                                         interpolation=InterpolationMode.BILINEAR),
        ])

        # Retained as an attribute for compatibility; __getitem__ does not load masks.
        self.all_gt_dual_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((patch_size, patch_size), scale=(0.36, 1.00),
                                         interpolation=InterpolationMode.NEAREST),
        ])

        # Final transform without randomness
        self.final_im_transform = transforms.Compose([
            RGB2Lab(),
            ToTensor(),
            im_rgb2lab_normalization,
        ])

    def _sample_frame_indices(self, length):
        """Sample ``self.num_frames`` distinct frame indices from ``range(length)``.

        Each new index is drawn uniformly from the not-yet-picked frames within
        ``max_jump`` of any picked frame. The result is sorted, then reversed
        with probability 0.5.
        """
        # Clamp to >= 1 so the candidate set cannot become empty while frames
        # remain (max_jump <= 0 would make np.random.choice fail on []).
        this_max_jump = max(1, min(length, self.max_jump))

        def window(center):
            return set(range(max(0, center - this_max_jump), min(length, center + this_max_jump + 1)))

        frames_idx = [np.random.randint(length)]
        acceptable_set = window(frames_idx[-1]).difference(set(frames_idx))
        while len(frames_idx) < self.num_frames:
            next_idx = np.random.choice(list(acceptable_set))
            frames_idx.append(next_idx)
            acceptable_set = acceptable_set.union(window(next_idx)).difference(set(frames_idx))

        frames_idx = sorted(frames_idx)
        if np.random.rand() < 0.5:
            # Reverse time
            frames_idx = frames_idx[::-1]
        return frames_idx

    def _load_frame(self, im_path, sequence_seed):
        """Load one RGB frame and return it after all augmentations and the Lab conversion.

        ``reseed(sequence_seed)`` makes the ``all_*`` transforms identical across
        the frames of one sample. A fresh per-frame seed drives the ``pair_*``
        transforms.
        """
        reseed(sequence_seed)
        with Image.open(im_path) as img:
            this_im = img.convert('RGB')
        this_im = self.all_im_dual_transform(this_im)
        this_im = self.all_im_lone_transform(this_im)

        pairwise_seed = np.random.randint(2147483647)
        reseed(pairwise_seed)
        this_im = self.pair_im_dual_transform(this_im)
        this_im = self.pair_im_lone_transform(this_im)

        return self.final_im_transform(this_im)

    def __getitem__(self, idx):
        """Return one training sample for video ``self.videos[idx]``.

        Returns a dict with:
            'rgb': torch tensor stacking, per frame, the first channel repeated
                3 times. Shape (num_frames, 3, P, P), assuming ToTensor yields CHW.
            'first_frame_gt': channels 1-2 of the first frame, with a leading
                axis of size 1 (torch tensor).
            'cls_gt': channels 1-2 of every frame, stacked with np.stack
                (NumPy array).
            'selector': FloatTensor of length max_num_obj, 1 for the first
                ``info['num_objects']`` (= 2) entries.
            'info': dict with 'name', 'frames' (file names used) and 'num_objects'.
        """
        video = self.videos[idx]
        info = {'name': video, 'frames': []}

        vid_im_path = path.join(self.im_root, video)
        frames = self.frames[video]

        frames_idx = self._sample_frame_indices(len(frames))
        sequence_seed = np.random.randint(2147483647)

        images = []
        masks = []
        for f_idx in frames_idx:
            jpg_name = frames[f_idx]
            info['frames'].append(jpg_name)

            this_im = self._load_frame(path.join(vid_im_path, jpg_name), sequence_seed)

            # First channel is L (lightness), the next two are ab (chrominance).
            this_im_l = this_im[:1, :, :]
            this_im_ab = this_im[1:3, :, :]
            # L is replicated to 3 channels; presumably so the model's 3-channel input layout is unchanged.
            images.append(this_im_l.repeat(3, 1, 1))
            masks.append(this_im_ab)

        images = torch.stack(images, 0)

        first_frame_gt = masks[0].unsqueeze(0)
        info['num_objects'] = 2  # the two ab channels play the role of "objects"
        # np.stack converts the CPU tensors to a NumPy array (type kept for existing consumers).
        cls_gt = np.stack(masks, 0)

        # 1 if object exists, 0 otherwise
        selector = torch.FloatTensor([1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)])

        data = {
            'rgb': images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info,
        }

        return data

    def __len__(self):
        return len(self.videos)