| import os |
| import random |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
| from .augmentation import MotionAugmentation |
|
|
|
|
class VideoMatteDataset(Dataset):
    """Sequences of VideoMatte foreground/alpha frames paired with backgrounds.

    Every item consists of ``seq_length`` foreground frames, the alpha frames
    with the same clip and file names, and a background sequence that is
    drawn at random from either the background images or the background
    video clips.

    Directory layout read by this class::

        videomatte_dir/fgr/<clip>/<frame>
        videomatte_dir/pha/<clip>/<frame>
        background_image_dir/<image>
        background_video_dir/<clip>/<frame>

    Args:
        videomatte_dir: Root directory holding the ``fgr`` and ``pha`` trees.
        background_image_dir: Directory of still background images.
        background_video_dir: Directory with one sub-directory of frames per
            background clip.
        size: Images whose shorter side exceeds this value are downsampled
            so that the shorter side equals it.
        seq_length: Number of frames per returned sequence; also the stride
            between the start frames of consecutive items.
        seq_sampler: Callable invoked as ``seq_sampler(seq_length)`` that
            returns an iterable of integer frame offsets.
        transform: Optional callable invoked as
            ``transform(fgrs, phas, bgrs)``; its result is returned as the
            item when given.
    """

    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        self.background_image_dir = background_image_dir
        # Sorted like every other listing below: os.listdir returns entries
        # in arbitrary order, which would make a seeded random.choice pick
        # different files on different machines.
        self.background_image_files = sorted(os.listdir(background_image_dir))
        self.background_video_dir = background_video_dir
        self.background_video_clips = sorted(os.listdir(background_video_dir))
        self.background_video_frames = [
            sorted(os.listdir(os.path.join(background_video_dir, clip)))
            for clip in self.background_video_clips
        ]

        self.videomatte_dir = videomatte_dir
        fgr_root = os.path.join(videomatte_dir, 'fgr')
        self.videomatte_clips = sorted(os.listdir(fgr_root))
        self.videomatte_frames = [
            sorted(os.listdir(os.path.join(fgr_root, clip)))
            for clip in self.videomatte_clips
        ]
        # One entry per sequence start: (clip index, first frame index),
        # stepping through each clip in strides of seq_length.
        self.videomatte_idx = [
            (clip_idx, frame_idx)
            for clip_idx, frames in enumerate(self.videomatte_frames)
            for frame_idx in range(0, len(frames), seq_length)
        ]
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        """Return the number of sequence start positions."""
        return len(self.videomatte_idx)

    def __getitem__(self, idx):
        """Return ``(fgrs, phas, bgrs)`` for item ``idx``.

        The background source (still image or video clip) is picked with
        equal probability. When a transform was supplied, its return value
        is returned instead of the raw lists.
        """
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_videomatte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    def _get_random_image_background(self):
        """Load one random background image, repeated ``seq_length`` times."""
        path = os.path.join(self.background_image_dir,
                            random.choice(self.background_image_files))
        with Image.open(path) as bgr:
            # convert() yields a new image, so it stays usable once the
            # file handle is closed.
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        # The same image object is referenced by every list slot.
        return [bgr] * self.seq_length

    def _get_random_video_background(self):
        """Load a background frame sequence from a random clip and offset."""
        clip_idx = random.randrange(len(self.background_video_clips))
        clip = self.background_video_clips[clip_idx]
        frames = self.background_video_frames[clip_idx]
        frame_count = len(frames)
        # max(1, ...) keeps the range non-empty for clips that are not
        # longer than seq_length.
        start = random.randrange(max(1, frame_count - self.seq_length))
        bgrs = []
        for offset in self.seq_sampler(self.seq_length):
            # Wrap around so offsets past the end of the clip stay valid.
            frame = frames[(start + offset) % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _get_videomatte(self, idx):
        """Load the foreground and alpha frame lists for item ``idx``."""
        clip_idx, frame_idx = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frames = self.videomatte_frames[clip_idx]
        frame_count = len(frames)
        fgrs, phas = [], []
        for offset in self.seq_sampler(self.seq_length):
            # Wrap around so the last sequence of a clip is still full length.
            frame = frames[(frame_idx + offset) % frame_count]
            fgr_path = os.path.join(self.videomatte_dir, 'fgr', clip, frame)
            pha_path = os.path.join(self.videomatte_dir, 'pha', clip, frame)
            with Image.open(fgr_path) as fgr, Image.open(pha_path) as pha:
                fgr = self._downsample_if_needed(fgr.convert('RGB'))
                pha = self._downsample_if_needed(pha.convert('L'))
            fgrs.append(fgr)
            phas.append(pha)
        return fgrs, phas

    def _downsample_if_needed(self, img):
        """Shrink ``img`` so its shorter side equals ``self.size``.

        Images whose shorter side is already at most ``self.size`` are
        returned unchanged; the aspect ratio is kept otherwise.
        """
        w, h = img.size
        short_side = min(w, h)
        if short_side > self.size:
            scale = self.size / short_side
            # round() instead of int(): floating-point error in
            # scale * short_side can land just below self.size, and
            # truncation would then yield a side of self.size - 1.
            img = img.resize((round(scale * w), round(scale * h)))
        return img
|
|
class VideoMatteTrainAugmentation(MotionAugmentation):
    """MotionAugmentation preset with fixed, non-zero ``prob_*`` settings.

    The meaning of each keyword is defined by ``MotionAugmentation``; this
    class only pins the values and forwards ``size`` unchanged.
    """

    def __init__(self, size):
        probabilities = {
            'prob_fgr_affine': 0.3,
            'prob_bgr_affine': 0.3,
            'prob_noise': 0.1,
            'prob_color_jitter': 0.3,
            'prob_grayscale': 0.02,
            'prob_sharpness': 0.1,
            'prob_blur': 0.02,
            'prob_hflip': 0.5,
            'prob_pause': 0.03,
        }
        super().__init__(size=size, **probabilities)
|
|
class VideoMatteValidAugmentation(MotionAugmentation):
    """MotionAugmentation preset with every ``prob_*`` setting fixed at 0.

    The meaning of each keyword is defined by ``MotionAugmentation``; this
    class only pins the values and forwards ``size`` unchanged.
    """

    def __init__(self, size):
        setting_names = (
            'prob_fgr_affine',
            'prob_bgr_affine',
            'prob_noise',
            'prob_color_jitter',
            'prob_grayscale',
            'prob_sharpness',
            'prob_blur',
            'prob_hflip',
            'prob_pause',
        )
        # Every probability keyword receives the same integer value 0.
        super().__init__(size=size, **dict.fromkeys(setting_names, 0))
|
|