| | import os |
| | import numpy as np |
| | import random |
| | import json |
| | import os |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| | from torchvision.transforms import functional as F |
| | from PIL import Image |
| |
|
| |
|
| | class CocoPanopticDataset(Dataset): |
| | def __init__(self, |
| | imgdir: str, |
| | anndir: str, |
| | annfile: str, |
| | transform=None): |
| | with open(annfile) as f: |
| | self.data = json.load(f)['annotations'] |
| | self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data)) |
| | self.imgdir = imgdir |
| | self.anndir = anndir |
| | self.transform = transform |
| | |
| | def __len__(self): |
| | return len(self.data) |
| | |
| | def __getitem__(self, idx): |
| | data = self.data[idx] |
| | img = self._load_img(data) |
| | seg = self._load_seg(data) |
| | |
| | if self.transform is not None: |
| | img, seg = self.transform(img, seg) |
| | |
| | return img, seg |
| |
|
| | def _load_img(self, data): |
| | with Image.open(os.path.join(self.imgdir, data['file_name'].replace('.png', '.jpg'))) as img: |
| | return img.convert('RGB') |
| | |
| | def _load_seg(self, data): |
| | with Image.open(os.path.join(self.anndir, data['file_name'])) as ann: |
| | ann.load() |
| | |
| | ann = np.array(ann, copy=False).astype(np.int32) |
| | ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2] |
| | seg = np.zeros(ann.shape, np.uint8) |
| | |
| | for segments_info in data['segments_info']: |
| | if segments_info['category_id'] in [1, 27, 32]: |
| | seg[ann == segments_info['id']] = 255 |
| | |
| | return Image.fromarray(seg) |
| | |
| |
|
| | class CocoPanopticTrainAugmentation: |
| | def __init__(self, size): |
| | self.size = size |
| | self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1) |
| | |
| | def __call__(self, img, seg): |
| | |
| | params = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.1, 0.1), |
| | scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size) |
| | img = F.affine(img, *params, interpolation=F.InterpolationMode.BILINEAR) |
| | seg = F.affine(seg, *params, interpolation=F.InterpolationMode.NEAREST) |
| | |
| | |
| | params = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(0.7, 1.3)) |
| | img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) |
| | seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) |
| | |
| | |
| | if random.random() < 0.5: |
| | img = F.hflip(img) |
| | seg = F.hflip(seg) |
| | |
| | |
| | img = self.jitter(img) |
| | |
| | |
| | img = F.to_tensor(img) |
| | seg = F.to_tensor(seg) |
| | |
| | return img, seg |
| | |
| |
|
| | class CocoPanopticValidAugmentation: |
| | def __init__(self, size): |
| | self.size = size |
| | |
| | def __call__(self, img, seg): |
| | |
| | params = transforms.RandomResizedCrop.get_params(img, scale=(1, 1), ratio=(1., 1.)) |
| | img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) |
| | seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) |
| | |
| | |
| | img = F.to_tensor(img) |
| | seg = F.to_tensor(seg) |
| | |
| | return img, seg |