# File size: 4,788 Bytes | 352cafd
import os
from os import path
from torch.utils.data.dataset import Dataset
from torchvision import transforms, utils
from torchvision.transforms import functional
from PIL import Image
import numpy as np
import progressbar
from dataset.make_bb_trans import *
import torch
def make_coord(shape, ranges=None, flatten=True):
    """Build the coordinates of the cell centers of a regular grid.

    Args:
        shape: Number of cells along each axis.
        ranges: Optional sequence with one ``(v0, v1)`` pair per axis giving
            the interval that axis spans. Every axis spans ``(-1, 1)`` when
            omitted.
        flatten: When true the grid dimensions are collapsed so the result
            has shape ``(prod(shape), len(shape))``; otherwise it has shape
            ``(*shape, len(shape))``.

    Returns:
        A float tensor holding one coordinate vector per grid cell.
    """
    axes = []
    for axis, size in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[axis]
        # Half the width of one cell: the first center sits this far from
        # ``lo`` and consecutive centers are two half-widths apart.
        half_cell = (hi - lo) / (2 * size)
        centers = lo + half_cell + (2 * half_cell) * torch.arange(size).float()
        axes.append(centers)
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)
    return grid.view(-1, grid.shape[-1]) if flatten else grid
def to_pixel_samples(img):
    """Convert an image tensor into per-pixel (coordinate, value) pairs.

    Args:
        img: Tensor of shape ``(C, H, W)``. A tensor of any other rank is
            treated as a single-channel signal, i.e. all of its values are
            flattened into one column.

    Returns:
        A tuple ``(coord, rgb)``. ``coord`` is the flattened result of
        ``make_coord`` for the last two dimensions of ``img``. ``rgb`` has one
        row per pixel and one column per channel, so for ``(C, H, W)`` input
        row ``k`` of ``rgb`` holds the values of the pixel at ``coord[k]``.
    """
    # Use the real channel count instead of a hard-coded 1 so that
    # multi-channel images still give exactly one value row per coordinate.
    channels = img.shape[0] if img.dim() == 3 else 1
    coord = make_coord(img.shape[-2:])
    # reshape (rather than view) also accepts non-contiguous input.
    rgb = img.reshape(channels, -1).permute(1, 0)
    return coord, rgb
def resize_fn(img, size):
    """Resize an image with bicubic interpolation and return it as a tensor.

    Args:
        img: Input accepted by ``transforms.ToPILImage``.
        size: Target size, forwarded unchanged to ``transforms.Resize``.

    Returns:
        The resized image converted back with ``transforms.ToTensor``.
    """
    as_pil = transforms.ToPILImage()(img)
    resized = transforms.Resize(size, Image.BICUBIC)(as_pil)
    return transforms.ToTensor()(resized)
class OfflineDataset_crm(Dataset):
    """Dataset of (image, segmentation, gt) triplets stored in one folder.

    The folder is expected to hold three kinds of files that share a prefix:
    ``_im.png``, ``_seg.png`` and ``_gt.png``. Only the image files are
    indexed; the other two paths are derived by replacing the last seven
    characters of the image path.

    Args:
        root: Directory that is listed (non-recursively) for sample files.
        in_memory: If true, every triplet is decoded once in ``__init__`` and
            kept as PIL images; otherwise files are read on every access.
        need_name: If true, ``__getitem__`` additionally returns the sample
            name and a dict with pixel-center coordinates and cell sizes.
        resize: If true, all three images are resized to 224x224.
        do_crop: If true, all three images are cropped to the box derived
            from the segmentation via ``get_bb_position`` / ``scale_bb_by``.
    """

    def __init__(self, root, in_memory=False, need_name=False, resize=False, do_crop=False):
        self.root = root
        self.need_name = need_name
        self.resize = resize
        self.do_crop = do_crop
        self.in_memory = in_memory

        # There are three kinds of files: _im.png, _seg.png, _gt.png.
        # Keep the (sorted) image files; their last 7 characters hold "im".
        names = sorted(os.listdir(root))
        self.im_list = [
            path.join(root, name) for name in names if 'im' in name[-7:].lower()
        ]

        print('%d images found' % len(self.im_list))

        # Make up some transforms
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

        self.gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.seg_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            ),
        ])

        if self.in_memory:
            print('Loading things into memory')
            self.images = []
            self.gts = []
            self.segs = []
            for im in progressbar.progressbar(self.im_list):
                image, seg, gt = self.load_tuple(im)

                self.images.append(image)
                self.segs.append(seg)
                self.gts.append(gt)

    # The two resizers are real methods rather than lambdas stored on the
    # instance: lambda attributes cannot be pickled, which breaks DataLoader
    # worker processes on platforms that use the "spawn" start method.
    def resize_bi(self, x):
        """Resize ``x`` to 224x224 (bilinear) if ``resize`` is set, else return it."""
        return x.resize((224, 224), Image.BILINEAR) if self.resize else x

    def resize_nr(self, x):
        """Resize ``x`` to 224x224 (nearest) if ``resize`` is set, else return it."""
        return x.resize((224, 224), Image.NEAREST) if self.resize else x

    @staticmethod
    def _read(file_path, mode):
        """Decode ``file_path`` into a PIL image of ``mode`` and close the file."""
        with Image.open(file_path) as handle:
            return handle.convert(mode)

    def load_tuple(self, im):
        """Load one sample given the path of its image file.

        Args:
            im: Path of the image file; the segmentation and gt paths are
                built from ``im[:-7]``.

        Returns:
            ``(image, seg, gt)`` as PIL images (RGB, L, L), cropped and
            resized according to the dataset options.
        """
        prefix = im[:-7]
        # The segmentation is decoded once and reused both for computing the
        # crop box and as the returned sample.
        seg = self._read(prefix + '_seg.png', 'L')
        crop_lambda = self.get_crop_lambda(seg)

        image = self.resize_bi(crop_lambda(self._read(im, 'RGB')))
        gt = self.resize_bi(crop_lambda(self._read(prefix + '_gt.png', 'L')))
        seg = self.resize_bi(crop_lambda(seg))

        return image, seg, gt

    def get_crop_lambda(self, seg):
        """Return a callable that crops an image around the segmentation.

        Args:
            seg: Single-channel segmentation image (only read when
                ``do_crop`` is set).

        Returns:
            The identity function when cropping is disabled or the bounding
            box cannot be computed; otherwise a function applying
            ``functional.crop`` with the scaled bounding box.
        """
        if not self.do_crop:
            return lambda x: x

        mask = np.array(seg)
        h, w = mask.shape
        try:
            bb = get_bb_position(mask)
            rmin, rmax, cmin, cmax = scale_bb_by(*bb, h, w, 0.15, 0.15)
        except Exception:
            # Best effort: fall back to the uncropped image when no box can be
            # derived, without also swallowing KeyboardInterrupt/SystemExit.
            return lambda x: x
        return lambda x: functional.crop(x, rmin, cmin, rmax - rmin, cmax - cmin)

    def __getitem__(self, idx):
        """Return the transformed sample at position ``idx``.

        Returns:
            ``(im, seg, gt)`` tensors, or, when ``need_name`` is set,
            ``(im, seg, gt, name, {'coord': ..., 'cell': ...})`` where
            ``coord`` holds the pixel-center coordinates of ``seg`` and
            ``cell`` the matching per-axis cell sizes ``2 / H`` and ``2 / W``.
        """
        if self.in_memory:
            im, seg, gt = self.images[idx], self.segs[idx], self.gts[idx]
        else:
            im, seg, gt = self.load_tuple(self.im_list[idx])

        im = self.im_transform(im)
        gt = self.gt_transform(gt)
        seg = self.seg_transform(seg)

        if not self.need_name:
            return im, seg, gt

        # The coordinate grid is only part of the extended return value, so
        # it is built only when that value is requested.
        height, width = seg.shape[-2:]
        hr_coord = make_coord(seg.shape[-2:])
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / height
        cell[:, 1] *= 2 / width

        name = os.path.basename(self.im_list[idx][:-7])
        return im, seg, gt, name, {'coord': hr_coord, 'cell': cell}

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.im_list)
if __name__ == '__main__':
    # Smoke test: index the validation folder, which prints the image count.
    # The class defined in this file is OfflineDataset_crm.
    o = OfflineDataset_crm('data/val_static')
|