import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
import cv2
import trimesh
from PIL import Image, ImageOps
from PIL.ImageFilter import GaussianBlur
from torch.utils.data import Dataset
class EvalDataset(Dataset):
    @staticmethod
    def modify_commandline_options(parser):
        return parser

    def __init__(self, opt, root=None):
        self.opt = opt
        self.projection_mode = 'orthogonal'

        # Path setup
        self.root = self.opt.dataroot
        if root is not None:
            self.root = root
        self.RENDER = os.path.join(self.root, 'RENDER')
        self.MASK = os.path.join(self.root, 'MASK')
        self.PARAM = os.path.join(self.root, 'PARAM')
        self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')

        self.phase = 'val'
        self.load_size = self.opt.loadSize
        self.num_views = self.opt.num_views

        self.max_view_angle = 360
        self.interval = 1
        self.subjects = self.get_subjects()

        # PIL to tensor
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    def get_subjects(self):
        var_file = os.path.join(self.root, 'val.txt')
        if os.path.exists(var_file):
            var_subjects = np.loadtxt(var_file, dtype=str)
            return sorted(list(var_subjects))
        all_subjects = os.listdir(self.RENDER)
        return sorted(list(all_subjects))

    def __len__(self):
        return len(self.subjects) * self.max_view_angle // self.interval
    def get_render(self, subject, num_views, view_id=None, random_sample=False):
        '''
        Return the render data
        :param subject: subject name
        :param num_views: how many views to return
        :param view_id: the first view_id. If None, select a random one.
        :return:
            'img': [num_views, C, W, H] images
            'calib': [num_views, 4, 4] calibration matrix
            'extrinsic': [num_views, 4, 4] extrinsic matrix
            'mask': [num_views, 1, W, H] masks
        '''
        # For now we only have pitch = 00. Hard code it here
        pitch = 0

        # Select a random view_id from self.max_view_angle if not given
        if view_id is None:
            view_id = np.random.randint(self.max_view_angle)
        # The ids are an even distribution of num_views around view_id
        view_ids = [(view_id + self.max_view_angle // num_views * offset) % self.max_view_angle
                    for offset in range(num_views)]
        if random_sample:
            view_ids = np.random.choice(self.max_view_angle, num_views, replace=False)

        calib_list = []
        render_list = []
        mask_list = []
        extrinsic_list = []

        for vid in view_ids:
            param_path = os.path.join(self.PARAM, subject, '%d_%02d.npy' % (vid, pitch))
            render_path = os.path.join(self.RENDER, subject, '%d_%02d.jpg' % (vid, pitch))
            mask_path = os.path.join(self.MASK, subject, '%d_%02d.png' % (vid, pitch))

            # loading calibration data; the param file stores a pickled dict,
            # so allow_pickle is required with newer numpy versions
            param = np.load(param_path, allow_pickle=True)
            # pixel unit / world unit
            ortho_ratio = param.item().get('ortho_ratio')
            # world unit / model unit
            scale = param.item().get('scale')
            # camera center world coordinate
            center = param.item().get('center')
            # model rotation
            R = param.item().get('R')

            translate = -np.matmul(R, center).reshape(3, 1)
            extrinsic = np.concatenate([R, translate], axis=1)
            extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
            # Match camera space to image pixel space
            scale_intrinsic = np.identity(4)
            scale_intrinsic[0, 0] = scale / ortho_ratio
            scale_intrinsic[1, 1] = -scale / ortho_ratio
            scale_intrinsic[2, 2] = -scale / ortho_ratio
            # Match image pixel space to image uv space
            uv_intrinsic = np.identity(4)
            uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
            uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
            uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
            # Transform under image pixel space
            trans_intrinsic = np.identity(4)

            mask = Image.open(mask_path).convert('L')
            render = Image.open(render_path).convert('RGB')

            intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
            calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
            extrinsic = torch.Tensor(extrinsic).float()

            mask = transforms.Resize(self.load_size)(mask)
            mask = transforms.ToTensor()(mask).float()
            mask_list.append(mask)

            render = self.to_tensor(render)
            render = mask.expand_as(render) * render
            render_list.append(render)

            calib_list.append(calib)
            extrinsic_list.append(extrinsic)

        return {
            'img': torch.stack(render_list, dim=0),
            'calib': torch.stack(calib_list, dim=0),
            'extrinsic': torch.stack(extrinsic_list, dim=0),
            'mask': torch.stack(mask_list, dim=0)
        }
    def get_item(self, index):
        # In case of a missing file or IO error, switch to a random sample instead
        try:
            sid = index % len(self.subjects)
            vid = (index // len(self.subjects)) * self.interval
            # name of the subject 'rp_xxxx_xxx'
            subject = self.subjects[sid]
            res = {
                'name': subject,
                'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
                'sid': sid,
                'vid': vid,
            }
            render_data = self.get_render(subject, num_views=self.num_views, view_id=vid,
                                          random_sample=self.opt.random_multiview)
            res.update(render_data)
            return res
        except Exception as e:
            print(e)
            return self.get_item(index=random.randint(0, self.__len__() - 1))

    def __getitem__(self, index):
        return self.get_item(index)
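

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes an
# `opt` namespace exposing the options this class reads above (dataroot,
# loadSize, num_views, random_multiview); the directory path and values are
# placeholders, and the dataroot is expected to contain the RENDER / MASK /
# PARAM / GEO/OBJ subfolders used by __init__.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(dataroot='./eval_data', loadSize=512,
                    num_views=1, random_multiview=False)
    dataset = EvalDataset(opt)
    print('number of samples:', len(dataset))

    sample = dataset[0]
    # 'img' is [num_views, 3, H, W], 'mask' is [num_views, 1, H, W],
    # 'calib' and 'extrinsic' are [num_views, 4, 4]
    print(sample['name'], sample['img'].shape, sample['calib'].shape)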