# Camera pose / projection utilities (web-scrape residue "Spaces: Build error" removed)
| # partially from https://github.com/chenhsuanlin/signed-distance-SRN | |
| import numpy as np | |
| import torch | |
class Pose():
    """Utility for building, inverting, and composing [..., 3, 4] rigid poses."""

    def __call__(self, R=None, t=None):
        """Assemble a [..., 3, 4] pose from rotation R and/or translation t.

        A missing R defaults to identity; a missing t defaults to zeros.
        At least one of the two must be given.
        """
        assert(R is not None or t is not None)
        if R is None:
            t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
            R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)
        elif t is None:
            R = R if isinstance(R, torch.Tensor) else torch.tensor(R)
            t = torch.zeros(R.shape[:-1], device=R.device)
        else:
            R = R if isinstance(R, torch.Tensor) else torch.tensor(R)
            t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
            assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3, 3))
        # concatenate the translation as the last column -> [..., 3, 4]
        pose = torch.cat([R.float(), t.float()[..., None]], dim=-1)
        assert(pose.shape[-2:]==(3, 4))
        return pose

    def invert(self, pose, use_inverse=False):
        """Invert a pose: rotation becomes R^T (or R^-1), translation -R^T t."""
        R, t = pose[..., :3], pose[..., 3:]
        R_inv = R.inverse() if use_inverse else R.transpose(-1, -2)
        t_inv = (-R_inv@t)[..., 0]
        return self(R=R_inv, t=t_inv)

    def compose(self, pose_list):
        """Chain poses so pose_new(x) = poseN(...(pose2(pose1(x)))...)."""
        pose_new = pose_list[0]
        for next_pose in pose_list[1:]:
            pose_new = self.compose_pair(pose_new, next_pose)
        return pose_new

    def compose_pair(self, pose_a, pose_b):
        """Compose two poses so that pose_new(x) = pose_b(pose_a(x))."""
        R_a, t_a = pose_a[..., :3], pose_a[..., 3:]
        R_b, t_b = pose_b[..., :3], pose_b[..., 3:]
        R_new = R_b@R_a
        t_new = (R_b@t_a+t_b)[..., 0]
        return self(R=R_new, t=t_new)
# Shared module-level helper instance used to build/invert/compose poses.
pose = Pose()
# unit sphere normalization
def valid_norm_fac(seen_points, mask):
    '''
    Per-sample centroid and max radius of the mask-selected points.
    seen_points: [B, H*W, 3]
    mask: [B, 1, H, W], boolean
    Returns:
        means: [B, 3] centroid of valid points per sample
        max_dists: [B] largest distance of any valid point from its centroid
    '''
    batch_size = seen_points.shape[0]
    # flatten the mask to [B, H*W] so it indexes the point dimension
    mask = mask.view(batch_size, seen_points.shape[1])
    means, max_dists = [], []
    for pts_b, mask_b in zip(seen_points, mask):
        valid = pts_b[mask_b]               # [N_valid, 3]
        center = valid.mean(dim=0)          # [3]
        # scalar: farthest valid point from the centroid
        radius = (valid - center).norm(dim=1).max()
        means.append(center)
        max_dists.append(radius)
    return torch.stack(means, dim=0), torch.stack(max_dists, dim=0)
def get_pixel_grid(opt, H, W):
    """Return homogeneous pixel coordinates (x, y, 1) of an H x W image, shape [H*W, 3]."""
    ys = torch.arange(H, dtype=torch.float32).to(opt.device)
    xs = torch.arange(W, dtype=torch.float32).to(opt.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    ones = torch.ones_like(grid_y)
    # row-major (y outer, x inner) flattening to [H*W, 3]
    return torch.stack([grid_x, grid_y, ones], dim=-1).view(-1, 3)
def unproj_depth(opt, depth, intr):
    '''
    Back-project a depth map into 3D points in camera coordinates.
    depth: [B, 1, H, W]
    intr: [B, 3, 3]
    Returns:
        seen_points: [B, H*W, 3] in camera coordinates
    '''
    batch_size, _, H, W = depth.shape
    # the pipeline assumes square images of the configured size
    assert opt.H == H == W
    depth = depth.squeeze(1)
    # [B, 3, 3] inverse intrinsics map pixels to rays
    K_inv = torch.linalg.inv(intr).float()
    # [B, H*W, 3] homogeneous pixel grid replicated per batch sample
    pixel_grid = get_pixel_grid(opt, H, W).unsqueeze(0).repeat(batch_size, 1, 1)
    # [B, 3, H*W] ray directions through each pixel
    ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
    # scale each ray by its depth -> [B, H*W, 3]
    return ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)
def to_hom(X):
    '''
    Append a homogeneous coordinate of 1 to each point.
    X: [B, N, 3]
    Returns:
        X_hom: [B, N, 4]
    '''
    ones = torch.ones_like(X[..., :1])
    return torch.cat([X, ones], dim=-1)
def world2cam(X_world, pose):
    '''
    Transform world-space points into camera space with a [R|t] pose.
    X_world: [B, N, 3]
    pose: [B, 3, 4]
    Returns:
        X_cam: [B, N, 3]
    '''
    # homogenize, then right-multiply by pose^T: (pose @ X_hom^T)^T
    return to_hom(X_world) @ pose.transpose(-1, -2)
def cam2img(X_cam, cam_intr):
    '''
    Project camera-space points with the intrinsics matrix (no perspective divide).
    X_cam: [B, N, 3]
    cam_intr: [B, 3, 3]
    Returns:
        X_img: [B, N, 3]
    '''
    # (K @ X^T)^T written as a right multiplication by K^T
    return X_cam @ cam_intr.transpose(-1, -2)
def proj_points(opt, points, intr, pose):
    '''
    Project world-space points to 2D image coordinates.
    points: [B, N, 3]
    intr: [B, 3, 3]
    pose: [B, 3, 4]
    Returns:
        points_2D: [B, N, 2] pixel coordinates after perspective divide
        depth: [B, N] camera-space depth of each point
    '''
    points_cam = world2cam(points, pose)     # [B, N, 3]
    depth = points_cam[..., 2]               # [B, N]
    points_img = cam2img(points_cam, intr)   # [B, N, 3]
    # perspective divide by the homogeneous (depth) coordinate
    points_2D = points_img[..., :2] / points_img[..., 2:]
    return points_2D, depth
def azim_to_rotation_matrix(azim, representation='angle'):
    """Rotation about +Y: azim is the angle with +X, rotated in the XZ plane.

    azim: [B] angles ('angle' = degrees, 'rad' = radians) or [B, 2] (cos, sin)
        pairs when representation == 'trig'.
    Returns: [B, 3, 3] rotation matrices.
    Raises:
        ValueError: for an unknown representation (previously this fell
        through every branch and crashed with an unbound-local NameError).
    """
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(azim), torch.sin(azim)
    elif representation == 'angle':
        # [B, ] degrees -> radians
        azim = azim * np.pi / 180
        cos, sin = torch.cos(azim), torch.sin(azim)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = azim[:, 0], azim[:, 1]
    else:
        raise ValueError(f"unknown angle representation: {representation!r}")
    R = torch.eye(3, device=azim.device)[None].repeat(len(azim), 1, 1)
    zeros = torch.zeros(len(azim), device=azim.device)
    R[:, 0, :] = torch.stack([cos, zeros, sin], dim=-1)
    R[:, 2, :] = torch.stack([-sin, zeros, cos], dim=-1)
    return R
def elev_to_rotation_matrix(elev, representation='angle'):
    """Rotation about +X: elev is the angle with +Z in the YZ plane.

    elev: [B] angles ('angle' = degrees, 'rad' = radians) or [B, 2] (cos, sin)
        pairs when representation == 'trig'.
    Returns: [B, 3, 3] rotation matrices.
    Raises:
        ValueError: for an unknown representation (previously this fell
        through every branch and crashed with an unbound-local NameError).
    """
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(elev), torch.sin(elev)
    elif representation == 'angle':
        # [B, ] degrees -> radians
        elev = elev * np.pi / 180
        cos, sin = torch.cos(elev), torch.sin(elev)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = elev[:, 0], elev[:, 1]
    else:
        raise ValueError(f"unknown angle representation: {representation!r}")
    R = torch.eye(3, device=elev.device)[None].repeat(len(elev), 1, 1)
    R[:, 1, 1:] = torch.stack([cos, -sin], dim=-1)
    R[:, 2, 1:] = torch.stack([sin, cos], dim=-1)
    return R
def roll_to_rotation_matrix(roll, representation='angle'):
    """Rotation about +Z: roll is the angle with +X in the XY plane.

    roll: [B] angles ('angle' = degrees, 'rad' = radians) or [B, 2] (cos, sin)
        pairs when representation == 'trig'.
    Returns: [B, 3, 3] rotation matrices.
    Raises:
        ValueError: for an unknown representation (previously this fell
        through every branch and crashed with an unbound-local NameError).
    """
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(roll), torch.sin(roll)
    elif representation == 'angle':
        # [B, ] degrees -> radians
        roll = roll * np.pi / 180
        cos, sin = torch.cos(roll), torch.sin(roll)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = roll[:, 0], roll[:, 1]
    else:
        raise ValueError(f"unknown angle representation: {representation!r}")
    R = torch.eye(3, device=roll.device)[None].repeat(len(roll), 1, 1)
    R[:, 0, :2] = torch.stack([cos, sin], dim=-1)
    R[:, 1, :2] = torch.stack([-sin, cos], dim=-1)
    return R
def get_rotation_sphere(azim_sample=4, elev_sample=4, roll_sample=4, scales=(1.0,), device='cuda'):
    """Enumerate scaled rotations over a grid of azimuth/elevation/roll angles.

    azim_sample, elev_sample, roll_sample: number of evenly spaced samples in
        [0, 360) for each Euler angle.
    scales: iterable of scalar multipliers applied to each rotation
        (default changed from a mutable list [1.0] to an equivalent tuple).
    device: device the stacked result is moved to.
    Returns: [len(scales)*azim_sample*elev_sample*roll_sample, 3, 3] tensor,
        ordered by (scale, azim, elev, roll) — same order as before.
    """
    azims = np.linspace(0, 360, num=azim_sample, endpoint=False)
    elevs = np.linspace(0, 360, num=elev_sample, endpoint=False)
    rolls = np.linspace(0, 360, num=roll_sample, endpoint=False)
    # Each basis matrix depends on a single angle, so build them once instead
    # of rebuilding all three inside the innermost loop (O(A*E*R) -> O(A+E+R)).
    Rys = [azim_to_rotation_matrix(torch.tensor([azim])) for azim in azims]
    Rxs = [elev_to_rotation_matrix(torch.tensor([elev])) for elev in elevs]
    Rzs = [roll_to_rotation_matrix(torch.tensor([roll])) for roll in rolls]
    # Constant axis-permutation matrix applied after the Euler rotations;
    # [1, 3, 3] broadcasts against the [1, 3, 3] per-angle matrices.
    R_permute = torch.tensor([
        [-1, 0, 0],
        [0, 0, -1],
        [0, -1, 0]
    ]).float().unsqueeze(0)
    rotations = []
    for scale in scales:
        for Ry in Rys:
            for Rx in Rxs:
                for Rz in Rzs:
                    R = scale * Rz@Rx@Ry@R_permute
                    rotations.append(R.to(device).float())
    return torch.cat(rotations, dim=0)