| import os | |
| import random | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from .augmentation import MotionAugmentation | |
class VideoMatteDataset(Dataset):
    """Dataset of foreground/alpha frame sequences paired with random backgrounds.

    Expected directory layout (as read by this class):

    * ``videomatte_dir/fgr/<clip>/<frame>`` -- foreground frames.
    * ``videomatte_dir/pha/<clip>/<frame>`` -- alpha frames, looked up with the
      same clip and frame names as the foreground.
    * ``background_video_dir/<clip>/<frame>`` -- background video frames.
    * ``background_image_dir/<file>`` -- still background images.

    One sample is generated for every ``seq_length``-th frame of every
    foreground clip.

    Args:
        videomatte_dir: Root directory containing the ``fgr`` and ``pha`` trees.
        background_image_dir: Directory of still background images.
        background_video_dir: Directory of background clips (one sub-directory
            per clip).
        size: Images whose shorter side exceeds this value are shrunk so the
            shorter side equals it (expected to be an integer pixel count).
        seq_length: Stride between sample start frames; also the value handed
            to ``seq_sampler`` and the number of copies of a still background.
        seq_sampler: Callable invoked as ``seq_sampler(seq_length)`` that
            returns an iterable of integer frame offsets.
        transform: Optional callable invoked as ``transform(fgrs, phas, bgrs)``;
            when given, its return value is what ``__getitem__`` returns.
    """

    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        self.background_image_dir = background_image_dir
        # os.listdir() returns entries in arbitrary order; sort (like every
        # other listing below) so a seeded RNG picks the same file everywhere.
        self.background_image_files = sorted(os.listdir(background_image_dir))

        self.background_video_dir = background_video_dir
        self.background_video_clips = sorted(os.listdir(background_video_dir))
        self.background_video_frames = [
            sorted(os.listdir(os.path.join(background_video_dir, clip)))
            for clip in self.background_video_clips
        ]

        self.videomatte_dir = videomatte_dir
        self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
        self.videomatte_frames = [
            sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip)))
            for clip in self.videomatte_clips
        ]
        # Flat index of (clip index, start frame index); start frames are
        # spaced seq_length apart within each clip.
        self.videomatte_idx = [
            (clip_idx, frame_idx)
            for clip_idx, frames in enumerate(self.videomatte_frames)
            for frame_idx in range(0, len(frames), seq_length)
        ]

        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        """Return the number of (clip, start frame) samples."""
        return len(self.videomatte_idx)

    def __getitem__(self, idx):
        """Return ``(fgrs, phas, bgrs)`` for sample ``idx``.

        The background is a still image or a video clip with equal
        probability. If a transform was supplied, its result is returned
        instead of the raw tuple.
        """
        # The background is drawn before the foreground; keep this order so
        # the sequence of RNG draws stays the same.
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_videomatte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        return fgrs, phas, bgrs

    def _get_random_image_background(self):
        """Return a list of ``seq_length`` references to one random RGB image."""
        filename = random.choice(self.background_image_files)
        with Image.open(os.path.join(self.background_image_dir, filename)) as raw:
            bgr = self._downsample_if_needed(raw.convert('RGB'))
        # The same image object is repeated, not copied.
        return [bgr] * self.seq_length

    def _get_random_video_background(self):
        """Return RGB frames from a random background clip.

        A random start frame is chosen and one frame is loaded per offset
        produced by ``seq_sampler``; indices wrap around the end of the clip.
        """
        clip_idx = random.randrange(len(self.background_video_clips))
        clip = self.background_video_clips[clip_idx]
        frames = self.background_video_frames[clip_idx]
        frame_count = len(frames)
        # max(1, ...) keeps the range non-empty for clips that are not longer
        # than seq_length.
        start = random.randrange(max(1, frame_count - self.seq_length))

        bgrs = []
        for offset in self.seq_sampler(self.seq_length):
            frame = frames[(start + offset) % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as raw:
                bgrs.append(self._downsample_if_needed(raw.convert('RGB')))
        return bgrs

    def _get_videomatte(self, idx):
        """Return ``(fgrs, phas)`` frame lists for sample ``idx``.

        Foregrounds are converted to RGB and alphas to single-channel ``L``.
        Frame indices wrap around the end of the clip.
        """
        clip_idx, start = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frames = self.videomatte_frames[clip_idx]
        frame_count = len(frames)

        fgrs, phas = [], []
        for offset in self.seq_sampler(self.seq_length):
            frame = frames[(start + offset) % frame_count]
            with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as raw_fgr, \
                 Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as raw_pha:
                fgrs.append(self._downsample_if_needed(raw_fgr.convert('RGB')))
                phas.append(self._downsample_if_needed(raw_pha.convert('L')))
        return fgrs, phas

    def _downsample_if_needed(self, img):
        """Shrink ``img`` so its shorter side equals ``self.size``, if larger.

        Images whose shorter side is already at most ``self.size`` are
        returned unchanged. The aspect ratio is kept, with the longer side
        rounded down.
        """
        w, h = img.size
        short = min(w, h)
        if short > self.size:
            # Multiply before dividing and use floor division: the float form
            # int((size / short) * side) can land just below the exact value
            # (e.g. (1 / 49) * 49 == 0.9999999999999999) and lose a pixel.
            new_w = int(w * self.size // short)
            new_h = int(h * self.size // short)
            img = img.resize((new_w, new_h))
        return img
class VideoMatteTrainAugmentation(MotionAugmentation):
    """Training-time augmentation: ``MotionAugmentation`` with fixed probabilities.

    Args:
        size: Forwarded unchanged to ``MotionAugmentation`` as ``size``.
    """

    def __init__(self, size):
        # Per-augmentation probabilities, forwarded as keyword arguments.
        # NOTE(review): exact meaning of each is defined by MotionAugmentation.
        probabilities = {
            'prob_fgr_affine': 0.3,
            'prob_bgr_affine': 0.3,
            'prob_noise': 0.1,
            'prob_color_jitter': 0.3,
            'prob_grayscale': 0.02,
            'prob_sharpness': 0.1,
            'prob_blur': 0.02,
            'prob_hflip': 0.5,
            'prob_pause': 0.03,
        }
        super().__init__(size=size, **probabilities)
class VideoMatteValidAugmentation(MotionAugmentation):
    """Validation-time augmentation: ``MotionAugmentation`` with every probability 0.

    Args:
        size: Forwarded unchanged to ``MotionAugmentation`` as ``size``.
    """

    def __init__(self, size):
        # Every probability keyword is passed as 0.
        disabled = dict.fromkeys(
            (
                'prob_fgr_affine',
                'prob_bgr_affine',
                'prob_noise',
                'prob_color_jitter',
                'prob_grayscale',
                'prob_sharpness',
                'prob_blur',
                'prob_hflip',
                'prob_pause',
            ),
            0,
        )
        super().__init__(size=size, **disabled)