import os
from os import path

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np

from dataset.range_transform import im_normalization, im_mean
from dataset.tps import random_tps_warp
from dataset.reseed import reseed


class StaticTransformDataset(Dataset):
    """
    Generate pseudo VOS data by applying random transforms on static images.

    Every static image contributes a single object. When max_num_obj > 1,
    __getitem__ composites several independently transformed images into one
    multi-object sequence.

    Each entry of `parameters` is a (root, method, multiplier) triple:
    Method 0 - FSS style (class/1.jpg class/1.png)
    Method 1 - Others style (XXX.jpg XXX.png)
    `multiplier` repeats the images found under `root` that many times.

    Args:
        parameters: iterable of (root, method, multiplier) triples.
        num_frames: number of pseudo frames generated per sample.
        max_num_obj: maximum number of objects composited into one sample.
        size: side length (in pixels) of the square output frames.
    """

    def __init__(self, parameters, num_frames=3, max_num_obj=1, size=384):
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj
        self.size = size

        self.im_list = []
        for parameter in parameters:
            root, method, multiplier = parameter
            if method == 0:
                # One sub-directory per class; keep files whose last three characters are "jpg"
                classes = os.listdir(root)
                for c in classes:
                    imgs = os.listdir(path.join(root, c))
                    jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
                    joint_list = [path.join(root, c, im) for im in jpg_list]
                    self.im_list.extend(joint_list * multiplier)
            elif method == 1:
                # Flat directory; keep files whose name contains ".jpg"
                self.im_list.extend(
                    [path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier
                )

        print(f'{len(self.im_list)} images found.')

        # This set of transforms is the same for an im/gt pair, but differs among the sampled frames
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0),  # No hue change here as that's not realistic
        ])

        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9, 1.1), shear=10,
                                    interpolation=InterpolationMode.BICUBIC, fill=im_mean),
            transforms.Resize(size, InterpolationMode.BICUBIC),
            transforms.RandomCrop((size, size), pad_if_needed=True, fill=im_mean),
        ])

        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9, 1.1), shear=10,
                                    interpolation=InterpolationMode.BICUBIC, fill=0),
            transforms.Resize(size, InterpolationMode.NEAREST),
            transforms.RandomCrop((size, size), pad_if_needed=True, fill=0),
        ])

        # These transforms are the same for all pairs in the sampled sequence
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
            transforms.RandomGrayscale(0.05),
        ])

        self.all_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
            transforms.RandomHorizontalFlip(),
        ])

        self.all_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
            transforms.RandomHorizontalFlip(),
        ])

        # Final transforms without randomness
        self.final_im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])

        self.final_gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def _get_sample(self, idx):
        """
        Build one single-object pseudo sequence from the image at `idx`.

        The mask is read from the image path with its last three characters
        replaced by "png".

        Returns:
            (images, masks): `images` is the torch stack of num_frames
            transformed frames; `masks` is the matching stack of
            single-channel masks converted to a numpy array.
        """
        im_path = self.im_list[idx]
        # Context managers close the file handles deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(im_path) as im_file:
            im = im_file.convert('RGB')
        with Image.open(im_path[:-3] + 'png') as gt_file:
            gt = gt_file.convert('L')

        sequence_seed = np.random.randint(2147483647)

        images = []
        masks = []
        for _ in range(self.num_frames):
            # Re-seeding before the image and before the mask makes the paired
            # "dual" transforms draw the same random parameters for both.
            reseed(sequence_seed)
            this_im = self.all_im_dual_transform(im)
            this_im = self.all_im_lone_transform(this_im)
            reseed(sequence_seed)
            this_gt = self.all_gt_dual_transform(gt)

            pairwise_seed = np.random.randint(2147483647)
            reseed(pairwise_seed)
            this_im = self.pair_im_dual_transform(this_im)
            this_im = self.pair_im_lone_transform(this_im)
            reseed(pairwise_seed)
            this_gt = self.pair_gt_dual_transform(this_gt)

            # Use TPS only some of the time
            # Not because TPS is bad -- just that it is too slow and data loading needs to be fast
            if np.random.rand() < 0.33:
                this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)

            this_im = self.final_im_transform(this_im)
            this_gt = self.final_gt_transform(this_gt)

            images.append(this_im)
            masks.append(this_gt)

        images = torch.stack(images, 0)
        masks = torch.stack(masks, 0)

        return images, masks.numpy()

    def __getitem__(self, idx):
        """
        Return a training sample made of the image at `idx` plus a random
        number (0 .. max_num_obj-1) of additional, randomly chosen images.

        Returns a dict with:
            'rgb': composited image frames (torch tensor).
            'first_frame_gt': int64 array (1, max_num_obj, size, size) with
                one binary first-frame mask per object slot.
            'cls_gt': int64 array (num_frames, 1, size, size) of object ids,
                0 being background.
            'selector': FloatTensor of length max_num_obj, 1 for used slots.
            'info': dict with 'name' (image path) and 'num_objects'.
        """
        additional_objects = np.random.randint(self.max_num_obj)
        indices = [idx, *np.random.randint(len(self), size=additional_objects)]

        frame_shape = (self.num_frames, self.size, self.size)

        merged_images = None
        # np.int was removed in NumPy 1.24; use an explicit 64-bit integer dtype
        merged_masks = np.zeros(frame_shape, dtype=np.int64)

        for i, list_id in enumerate(indices):
            images, masks = self._get_sample(list_id)
            if merged_images is None:
                merged_images = images
            else:
                # Paste the new object over what has been composited so far
                merged_images = merged_images*(1-masks) + images*masks
            merged_masks[masks[:, 0] > 0.5] = (i+1)

        masks = merged_masks

        # Only objects visible in the first frame become targets
        labels = np.unique(masks[0])
        # Remove background
        labels = labels[labels != 0]
        target_objects = labels.tolist()

        # Generate one-hot ground-truth
        cls_gt = np.zeros(frame_shape, dtype=np.int64)
        first_frame_gt = np.zeros((1, self.max_num_obj, self.size, self.size), dtype=np.int64)
        for i, l in enumerate(target_objects):
            this_mask = (masks == l)
            # Relabel so that the target ids are contiguous, starting at 1
            cls_gt[this_mask] = i+1
            first_frame_gt[0, i] = (this_mask[0])
        cls_gt = np.expand_dims(cls_gt, 1)

        info = {}
        info['name'] = self.im_list[idx]
        info['num_objects'] = max(1, len(target_objects))

        # 1 if object exist, 0 otherwise
        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
        selector = torch.FloatTensor(selector)

        data = {
            'rgb': merged_images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info
        }

        return data

    def __len__(self):
        """Return the number of listed images, including multiplier repeats."""
        return len(self.im_list)