import os
import random

from PIL import Image
from torch.utils.data import Dataset

from .augmentation import MotionAugmentation
class ImageMatteDataset(Dataset):
    """Synthesizes fake video clips from still foreground/alpha mattes.

    Each sample is a ``(fgrs, phas, bgrs)`` triple of PIL image lists of
    length ``seq_length``: the foreground/alpha pair is one matte repeated,
    while the background is either a single still image repeated or
    consecutive frames drawn from a background video clip (50/50 chance).
    If ``transform`` is given, the triple is passed through it instead.
    """

    def __init__(self,
                 imagematte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform):
        """
        Args:
            imagematte_dir: directory with 'fgr' and 'pha' subfolders whose
                filenames match pairwise.
            background_image_dir: flat directory of background images.
            background_video_dir: directory of clip subfolders, each holding
                that clip's frame images.
            size: target short-side length for downsampling large images.
            seq_length: number of frames per returned sequence.
            seq_sampler: callable mapping seq_length to frame index offsets.
            transform: optional callable applied to (fgrs, phas, bgrs).
        """
        self.imagematte_dir = imagematte_dir
        # Filenames are shared between the 'fgr' and 'pha' subfolders.
        self.imagematte_files = os.listdir(os.path.join(imagematte_dir, 'fgr'))
        self.background_image_dir = background_image_dir
        self.background_image_files = os.listdir(background_image_dir)
        self.background_video_dir = background_video_dir
        self.background_video_clips = os.listdir(background_video_dir)
        # Sorted so frame order reflects temporal order within each clip.
        self.background_video_frames = [
            sorted(os.listdir(os.path.join(background_video_dir, clip)))
            for clip in self.background_video_clips
        ]
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.size = size
        self.transform = transform

    def __len__(self):
        # Long enough that an epoch can visit every matte and every
        # background source at least once.
        return max(len(self.imagematte_files),
                   len(self.background_image_files) + len(self.background_video_clips))

    def __getitem__(self, idx):
        # Background source is chosen at random, independent of idx.
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_imagematte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        return fgrs, phas, bgrs

    def _get_imagematte(self, idx):
        """Load one fgr/pha pair and repeat it seq_length times."""
        # Wrap idx so any index from __len__ maps onto the matte list.
        filename = self.imagematte_files[idx % len(self.imagematte_files)]
        with Image.open(os.path.join(self.imagematte_dir, 'fgr', filename)) as fgr, \
             Image.open(os.path.join(self.imagematte_dir, 'pha', filename)) as pha:
            fgr = self._downsample_if_needed(fgr.convert('RGB'))
            pha = self._downsample_if_needed(pha.convert('L'))
        # A still matte repeated: motion is added later by augmentation.
        fgrs = [fgr] * self.seq_length
        phas = [pha] * self.seq_length
        return fgrs, phas

    def _get_random_image_background(self):
        """Pick one random background image and repeat it seq_length times."""
        filename = self.background_image_files[random.randrange(len(self.background_image_files))]
        with Image.open(os.path.join(self.background_image_dir, filename)) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        return [bgr] * self.seq_length

    def _get_random_video_background(self):
        """Sample seq_length frames from one random background clip."""
        clip_idx = random.randrange(len(self.background_video_clips))
        frame_count = len(self.background_video_frames[clip_idx])
        # Random start; max(1, ...) keeps randrange valid for short clips.
        frame_idx = random.randrange(max(1, frame_count - self.seq_length))
        clip_dir = os.path.join(self.background_video_dir,
                                self.background_video_clips[clip_idx])
        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            # Modulo wraps around when the clip is shorter than the sequence.
            frame = self.background_video_frames[clip_idx][(frame_idx + i) % frame_count]
            with Image.open(os.path.join(clip_dir, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _downsample_if_needed(self, img):
        """Shrink img so its short side is at most self.size; never upsample."""
        w, h = img.size
        if min(w, h) > self.size:
            scale = self.size / min(w, h)
            w = int(scale * w)
            h = int(scale * h)
            img = img.resize((w, h))
        return img
class ImageMatteAugmentation(MotionAugmentation):
    """Motion-augmentation preset used with still-image mattes."""

    def __init__(self, size):
        # Heavy foreground affine motion (still mattes need synthetic
        # movement), lighter background motion, and mild photometric jitter.
        presets = dict(
            prob_fgr_affine=0.95,
            prob_bgr_affine=0.3,
            prob_noise=0.05,
            prob_color_jitter=0.3,
            prob_grayscale=0.03,
            prob_sharpness=0.05,
            prob_blur=0.02,
            prob_hflip=0.5,
            prob_pause=0.03,
        )
        super().__init__(size=size, **presets)