# zoo3d/MaskClustering/third_party/Entity/High-Quality-Segmention/dataset/offline_dataset_crm.py
| import os | |
| from os import path | |
| from torch.utils.data.dataset import Dataset | |
| from torchvision import transforms, utils | |
| from torchvision.transforms import functional | |
| from PIL import Image | |
| import numpy as np | |
| import progressbar | |
| from dataset.make_bb_trans import * | |
| import torch | |
def make_coord(shape, ranges=None, flatten=True):
    """Build the coordinates of grid-cell centers over a regular grid.

    shape   -- sizes of each grid dimension, e.g. (H, W)
    ranges  -- optional per-dimension (lo, hi) intervals; defaults to [-1, 1]
    flatten -- if True, return an (N, ndim) tensor instead of the full grid

    Returns a float tensor whose last dimension indexes the grid axes.
    """
    axes = []
    for dim, n in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[dim]
        # Half the cell width: centers sit half a cell in from each edge.
        half = (hi - lo) / (2 * n)
        axes.append(lo + half + (2 * half) * torch.arange(n).float())
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)
    if flatten:
        grid = grid.view(-1, grid.shape[-1])
    return grid
def to_pixel_samples(img):
    """Flatten a tensor image into (coordinate, value) pairs.

    img: Tensor whose trailing two dims are (H, W); values are flattened in
    row-major order so they line up with the coordinates from make_coord.

    NOTE(review): the original docstring said "(3, H, W)", but `view(1, -1)`
    only yields one value per coordinate — in this file it is called with the
    1-channel segmentation tensor. TODO confirm against callers.
    """
    return make_coord(img.shape[-2:]), img.view(1, -1).permute(1, 0)
def resize_fn(img, size):
    """Resize a tensor image to `size` with bicubic PIL resampling.

    Round-trips through PIL: tensor -> PIL image -> resized PIL -> tensor.
    """
    as_pil = transforms.ToPILImage()(img)
    resized = transforms.Resize(size, Image.BICUBIC)(as_pil)
    return transforms.ToTensor()(resized)
class OfflineDataset_crm(Dataset):
    """Dataset over a flat folder of image/segmentation/ground-truth triples.

    The folder holds files grouped by a shared prefix with three suffixes:
    '_im.png' (RGB input), '_seg.png' (initial segmentation, grayscale) and
    '_gt.png' (ground truth, grayscale). Samples are enumerated from the
    '_im.png' files; the siblings are derived from the same prefix.
    """

    def __init__(self, root, in_memory=False, need_name=False, resize=False, do_crop=False):
        """
        root      -- directory containing the *_im/_seg/_gt.png triples
        in_memory -- preload every triple into RAM up front
        need_name -- __getitem__ additionally returns the sample base name
                     and the coord/cell dict (for the CRM decoder)
        resize    -- resize all three images to 224x224
        do_crop   -- crop around the segmentation bounding box, scaled by 15%
        """
        self.root = root
        self.need_name = need_name
        self.resize = resize
        self.do_crop = do_crop
        self.in_memory = in_memory

        # Only '_im.png' files define the sample list (suffix check on the
        # last 7 chars, i.e. '_im.png').
        im_list = sorted(im for im in os.listdir(root) if 'im' in im[-7:].lower())
        self.im_list = [path.join(root, im) for im in im_list]
        print('%d images found' % len(self.im_list))

        # ImageNet mean/std normalization for the RGB input.
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        self.gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Map the single-channel segmentation from [0, 1] to [-1, 1].
        self.seg_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            ),
        ])

        if self.resize:
            self.resize_bi = lambda x: x.resize((224, 224), Image.BILINEAR)
            self.resize_nr = lambda x: x.resize((224, 224), Image.NEAREST)
        else:
            self.resize_bi = lambda x: x
            self.resize_nr = lambda x: x

        if self.in_memory:
            print('Loading things into memory')
            self.images = []
            self.gts = []
            self.segs = []
            for im in progressbar.progressbar(self.im_list):
                image, seg, gt = self.load_tuple(im)
                self.images.append(image)
                self.segs.append(seg)
                self.gts.append(gt)

    def load_tuple(self, im):
        """Load the (image, seg, gt) PIL triple for image path `im` (…_im.png).

        All three are cropped by the same bounding-box lambda (derived from
        the full-size segmentation) and bilinearly resized if enabled.
        """
        # Open the segmentation once and reuse it; the original code
        # re-opened '_seg.png' a second time after computing the crop.
        seg_full = Image.open(im[:-7] + '_seg.png').convert('L')
        crop_lambda = self.get_crop_lambda(seg_full)
        image = self.resize_bi(crop_lambda(Image.open(im).convert('RGB')))
        gt = self.resize_bi(crop_lambda(Image.open(im[:-7] + '_gt.png').convert('L')))
        seg = self.resize_bi(crop_lambda(seg_full))
        return image, seg, gt

    def get_crop_lambda(self, seg):
        """Return a function cropping to seg's scaled bounding box.

        Falls back to the identity when cropping is disabled or no valid
        bounding box can be extracted (e.g. an empty segmentation).
        """
        if not self.do_crop:
            return lambda x: x
        seg = np.array(seg)
        h, w = seg.shape
        try:
            bb = get_bb_position(seg)
            rmin, rmax, cmin, cmax = scale_bb_by(*bb, h, w, 0.15, 0.15)
            return lambda x: functional.crop(x, rmin, cmin, rmax - rmin, cmax - cmin)
        except Exception:
            # Was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit);
            # the best-effort fallback to no cropping is kept.
            return lambda x: x

    def __getitem__(self, idx):
        if self.in_memory:
            im = self.images[idx]
            gt = self.gts[idx]
            seg = self.segs[idx]
        else:
            im, seg, gt = self.load_tuple(self.im_list[idx])

        im = self.im_transform(im)
        gt = self.gt_transform(gt)
        seg = self.seg_transform(seg)

        # Per-pixel center coordinates of the segmentation grid, plus the
        # cell extent (2/H, 2/W) in the [-1, 1] normalized coordinate system.
        hr_coord, _ = to_pixel_samples(seg.contiguous())
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / seg.shape[-2]
        cell[:, 1] *= 2 / seg.shape[-1]
        # The original also computed `crop_lr = resize_fn(seg, seg.shape[-2])`
        # and `hr_rgb`, both unused (only referenced in a commented-out dict
        # entry); that dead work has been dropped.

        if self.need_name:
            # [:-7] strips the '_im.png' suffix to recover the sample prefix.
            return im, seg, gt, os.path.basename(self.im_list[idx][:-7]), {'coord': hr_coord, 'cell': cell}
        else:
            return im, seg, gt

    def __len__(self):
        return len(self.im_list)
if __name__ == '__main__':
    # Smoke test. The original referenced `OfflineDataset`, which is not
    # defined in this module (NameError); the class here is OfflineDataset_crm.
    o = OfflineDataset_crm('data/val_static')