File size: 3,552 Bytes
0e83290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os
from os import path
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn.functional as F
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization
class VideoReader(Dataset):
    """
    Dataset that yields the frames of a single video, one frame per item.
    """
    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
        """
        image_dir - directory of jpg frames
        mask_dir - directory of png masks
        size - resize the shorter image side to this many pixels; <0 disables resizing
        to_save - optional list of extension-less frame names whose segmentation
                  masks are required
        use_all_mask - when True, read every mask available in mask_dir
                       (e.g. for YouTubeVOS validation); default False
        size_dir - directory whose images define the output resolution
                   (defaults to image_dir)
        """
        self.vid_name = vid_name
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.to_save = to_save
        self.use_all_mask = use_all_mask
        self.size_dir = image_dir if size_dir is None else size_dir
        self.size = size

        self.frames = sorted(os.listdir(self.image_dir))

        # The palette of the first mask is reused for every saved output mask.
        mask_files = sorted(os.listdir(self.mask_dir))
        self.palette = Image.open(path.join(mask_dir, mask_files[0])).getpalette()
        self.first_gt_path = path.join(self.mask_dir, mask_files[0])

        steps = [transforms.ToTensor(), im_normalization]
        if size >= 0:
            steps.append(transforms.Resize(size, interpolation=InterpolationMode.BILINEAR))
        self.im_transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        img = Image.open(path.join(self.image_dir, frame)).convert('RGB')

        # Output resolution is taken from size_dir, which may differ from image_dir.
        if self.image_dir == self.size_dir:
            shape = np.array(img).shape[:2]
        else:
            ref = Image.open(path.join(self.size_dir, frame)).convert('RGB')
            shape = np.array(ref).shape[:2]

        info = {
            'frame': frame,
            'save': (self.to_save is None) or (frame[:-4] in self.to_save),
            'shape': shape,
            'need_resize': not (self.size < 0),
        }

        data = {'rgb': self.im_transform(img)}

        # Unless use_all_mask is set, only the very first mask is loaded.
        gt_path = path.join(self.mask_dir, frame[:-4] + '.png')
        if (self.use_all_mask or gt_path == self.first_gt_path) and path.exists(gt_path):
            data['mask'] = np.array(Image.open(gt_path).convert('P'), dtype=np.uint8)

        data['info'] = info
        return data

    def resize_mask(self, mask):
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        h, w = mask.shape[-2:]
        shorter = min(h, w)
        new_hw = (int(h / shorter * self.size), int(w / shorter * self.size))
        return F.interpolate(mask, new_hw, mode='nearest')

    def get_palette(self):
        return self.palette

    def __len__(self):
        return len(self.frames)