| r""" Superclass for semantic correspondence datasets """ | |
| import os | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torch | |
| from model.base.geometry import Geometry | |
class CorrespondenceDataset(Dataset):
    r""" Parent class of PFPascal, PFWillow, and SPair """
    def __init__(self, benchmark, datapath, thres, split):
        r""" CorrespondenceDataset constructor """
        super(CorrespondenceDataset, self).__init__()

        # {Directory name, Layout path, Image path, Annotation path, PCK threshold}
        self.metadata = {
            'pfwillow': ('PF-WILLOW',
                         'test_pairs.csv',
                         '',
                         '',
                         'bbox'),
            'pfpascal': ('PF-PASCAL',
                         '_pairs.csv',
                         'JPEGImages',
                         'Annotations',
                         'img'),
            'spair': ('SPair-71k',
                      'Layout/large',
                      'JPEGImages',
                      'PairAnnotation',
                      'bbox')
        }

        # Directory path for train, val, or test splits
        base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
        if benchmark == 'pfpascal':
            self.spt_path = os.path.join(base_path, split + '_pairs.csv')
        elif benchmark == 'spair':
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split + '.txt')
        else:
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])

        # Directory path for images
        self.img_path = os.path.join(base_path, self.metadata[benchmark][2])

        # Directory path for annotations
        if benchmark == 'spair':
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
        else:
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])
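
        # For example, with benchmark='spair' and split='test', the paths above resolve to
        #   spt_path: <abs datapath>/SPair-71k/Layout/large/test.txt
        #   img_path: <abs datapath>/SPair-71k/JPEGImages
        #   ann_path: <abs datapath>/SPair-71k/PairAnnotation/test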

        # Miscellaneous
        self.max_pts = 40
        self.split = split
        self.img_size = Geometry.img_size
        self.benchmark = benchmark
        self.range_ts = torch.arange(self.max_pts)
        self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
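        # Images are resized to a square of side Geometry.img_size and normalized with
        # the standard ImageNet mean/std expected by torchvision's pre-trained backbones.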

        # To get initialized in subclass constructors
        self.train_data = []
        self.src_imnames = []
        self.trg_imnames = []
        self.cls = []
        self.cls_ids = []
        self.src_kps = []
        self.trg_kps = []

    def __len__(self):
        r""" Returns the number of pairs """
        return len(self.train_data)

    def __getitem__(self, idx):
        r""" Constructs and returns a batch """
        # Image names
        batch = dict()
        batch['src_imname'] = self.src_imnames[idx]
        batch['trg_imname'] = self.trg_imnames[idx]

        # Object category
        batch['category_id'] = self.cls_ids[idx]
        batch['category'] = self.cls[batch['category_id']]

        # Image as PIL (original width, original height)
        src_pil = self.get_image(self.src_imnames, idx)
        trg_pil = self.get_image(self.trg_imnames, idx)
        batch['src_imsize'] = src_pil.size
        batch['trg_imsize'] = trg_pil.size

        # Image as tensor
        batch['src_img'] = self.transform(src_pil)
        batch['trg_img'] = self.transform(trg_pil)

        # Key-points (re-scaled to the resized image)
        batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
        batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
        batch['n_pts'] = torch.tensor(num_pts)

        # Total number of pairs in the split
        batch['datalen'] = len(self.train_data)

        return batch

    def get_image(self, imnames, idx):
        r""" Reads PIL image from path """
        path = os.path.join(self.img_path, imnames[idx])
        return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
        r""" Computes PCK threshold """
        if self.thres == 'bbox':
            bbox = batch['trg_bbox'].clone()
            bbox_w = (bbox[2] - bbox[0])
            bbox_h = (bbox[3] - bbox[1])
            pckthres = torch.max(bbox_w, bbox_h)
        elif self.thres == 'img':
            imsize_t = batch['trg_img'].size()
            pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
        else:
            raise Exception('Invalid pck threshold type: %s' % self.thres)
        return pckthres.float()
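
    # Usage note (assumption): evaluation code typically counts a transferred
    # key-point as correct when its distance to the ground truth is within
    # alpha * pckthres; the alpha factor is applied outside this class.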

    def get_points(self, pts_list, idx, org_imsize):
        r""" Returns key-points of an image """
        xy, n_pts = pts_list[idx].size()
        # Pad with the sentinel value -2 up to max_pts so that key-point tensors
        # of different lengths collate into fixed-size batches
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        # Re-scale coordinates from the original image size to the resized input
        x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
        y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
        return kps, n_pts
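

if __name__ == '__main__':
    # Minimal usage sketch (illustrative assumption, not one of the shipped loaders):
    # a toy subclass showing which attributes the concrete PF-PASCAL / PF-WILLOW /
    # SPair datasets are expected to populate in their constructors.
    class ToyDataset(CorrespondenceDataset):
        def __init__(self, benchmark, datapath, thres, split):
            super(ToyDataset, self).__init__(benchmark, datapath, thres, split)
            # One image pair of category 0 with two key-points each; each key-point
            # tensor has shape (2, n_pts): row 0 holds x- and row 1 y-coordinates.
            self.train_data = [0]
            self.src_imnames = ['src_000.jpg']
            self.trg_imnames = ['trg_000.jpg']
            self.cls = ['aeroplane']
            self.cls_ids = [0]
            self.src_kps = [torch.tensor([[10., 50.], [20., 60.]])]
            self.trg_kps = [torch.tensor([[12., 55.], [25., 70.]])]

    # Constructing the dataset only resolves paths; fetching an item additionally
    # requires the listed image files to exist under img_path.
    toy = ToyDataset('pfpascal', 'Datasets', 'auto', 'trn')
    print(len(toy), toy.spt_path, toy.thres)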