|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
import random |
|
|
from torch.utils.data import Dataset |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms import functional as F |
|
|
|
|
|
|
|
|
class YouTubeVISDataset(Dataset):
    """YouTube-VIS video dataset restricted to the 'person' category (id 26).

    Each item is a (imgs, segs) pair: a list of RGB PIL frames and a list of
    binary (0/255) PIL mask images sampled around a pivot frame, optionally
    passed through `transform`.
    """

    def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None):
        self.videodir = videodir
        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

        with open(annfile) as f:
            data = json.load(f)

        # Group person-category RLE masks by video, one bucket per frame.
        # Several annotations (instances) of the same video merge into the
        # same per-frame buckets; None entries (absent in frame) are dropped.
        self.masks = {}
        for ann in data['annotations']:
            if ann['category_id'] != 26:
                continue
            vid = ann['video_id']
            buckets = self.masks.setdefault(
                vid, [[] for _ in range(len(ann['segmentations']))])
            for bucket, rle in zip(buckets, ann['segmentations']):
                if rle is not None:
                    bucket.append(rle)

        # Keep only videos that received at least one person annotation.
        self.videos = {
            video['id']: video
            for video in data['videos']
            if video['id'] in self.masks
        }

        # Flat sample index: every (video, frame) pair is a valid pivot.
        self.index = [
            (vid, frame)
            for vid in self.videos
            for frame in range(len(self.videos[vid]['file_names']))
        ]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        video_id, pivot = self.index[idx]
        video = self.videos[video_id]
        n_frames = len(video['file_names'])
        H, W = video['height'], video['width']

        imgs, segs = [], []
        for offset in self.seq_sampler(self.seq_length):
            # Wrap around the clip so every pivot yields a full sequence.
            t = (pivot + offset) % n_frames

            path = os.path.join(self.videodir, video['file_names'][t])
            with Image.open(path) as img:
                imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR))

            # Union of all instance masks for this frame.
            seg = np.zeros((H, W), dtype=np.uint8)
            for rle in self.masks[video_id][t]:
                seg |= self._decode_rle(rle)
            segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST))

        if self.transform is not None:
            imgs, segs = self.transform(imgs, segs)

        return imgs, segs

    def _decode_rle(self, rle):
        """Decode an uncompressed COCO RLE dict into an (H, W) uint8 mask (0/255)."""
        H, W = rle['size']
        flat = np.zeros(H * W, dtype=np.uint8)
        counts = rle['counts']
        pos = 0
        # Counts alternate (background run, foreground run); a trailing
        # unpaired background run needs no drawing and is dropped by zip.
        for gap, run in zip(counts[0::2], counts[1::2]):
            pos += gap
            flat[pos:pos + run] = 255
            pos += run
        # COCO RLE is column-major, so reshape as (W, H) and transpose.
        return flat.reshape(W, H).transpose()

    def _downsample_if_needed(self, img, resample):
        """Shrink `img` so its short side equals self.size; never upsample."""
        w, h = img.size
        short = min(w, h)
        if short > self.size:
            ratio = self.size / short
            img = img.resize((int(ratio * w), int(ratio * h)), resample)
        return img
|
|
|
|
|
|
|
|
class YouTubeVISAugmentation:
    """Joint augmentation for a (frames, masks) video pair.

    Converts both lists of PIL images to stacked float tensors (T, C, H, W),
    then applies: one random resized crop shared by frames and masks, color
    jitter on the frames only, an occasional full grayscale conversion, and
    a shared horizontal flip.
    """

    def __init__(self, size):
        self.size = size
        self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15)

    def __call__(self, imgs, segs):
        # PIL -> float tensors in [0, 1], stacked along the time axis.
        frames = torch.stack([F.to_tensor(im) for im in imgs])
        masks = torch.stack([F.to_tensor(sg) for sg in segs])

        # Sample one crop window and apply it identically to both streams
        # so frames and masks stay aligned.
        top, left, height, width = transforms.RandomResizedCrop.get_params(
            frames, scale=(0.8, 1), ratio=(0.9, 1.1))
        frames = F.resized_crop(frames, top, left, height, width, self.size,
                                interpolation=F.InterpolationMode.BILINEAR)
        masks = F.resized_crop(masks, top, left, height, width, self.size,
                               interpolation=F.InterpolationMode.BILINEAR)

        # Photometric jitter touches only the RGB frames.
        frames = self.jitter(frames)

        # Rarely drop color entirely (kept as 3 channels).
        if random.random() < 0.05:
            frames = F.rgb_to_grayscale(frames, num_output_channels=3)

        # Geometric flip must stay shared between frames and masks.
        if random.random() < 0.5:
            frames = F.hflip(frames)
            masks = F.hflip(masks)

        return frames, masks
|
|
|