| import torch | |
| import numpy as np | |
| import trimesh | |
| from pytorch360convert import e2c, c2e | |
def erp_to_cubemap(erp_tensor, face_w=768, cube_format="stack", mode="bilinear", **kwargs):
    """Forward ``erp_tensor`` to ``pytorch360convert.e2c``.

    Args:
        erp_tensor: Input handed to ``e2c`` unchanged.
        face_w: Passed to ``e2c`` as ``face_w``.
        cube_format: Passed to ``e2c`` as ``cube_format``.
        mode: Passed to ``e2c`` as ``mode``.
        **kwargs: Any additional keyword arguments accepted by ``e2c``.

    Returns:
        Whatever ``e2c`` returns for the given arguments.
    """
    options = {"face_w": face_w, "cube_format": cube_format, "mode": mode}
    options.update(kwargs)
    return e2c(erp_tensor, **options)
def cubemap_to_erp(cube_tensor, erp_h=1024, erp_w=2048, cube_format="stack", mode="bilinear", **kwargs):
    """Forward ``cube_tensor`` to ``pytorch360convert.c2e``.

    Args:
        cube_tensor: Input handed to ``c2e`` unchanged.
        erp_h: Passed to ``c2e`` as ``h``.
        erp_w: Passed to ``c2e`` as ``w``.
        cube_format: Passed to ``c2e`` as ``cube_format``.
        mode: Passed to ``c2e`` as ``mode``.
        **kwargs: Any additional keyword arguments accepted by ``c2e``.

    Returns:
        Whatever ``c2e`` returns for the given arguments.
    """
    options = {"h": erp_h, "w": erp_w, "cube_format": cube_format, "mode": mode}
    options.update(kwargs)
    return c2e(cube_tensor, **options)
def roll_augment(data, shift_x):
    """Circularly shift an image-like array along its width axis.

    Layout handling:
        * 2-D ``(H, W)`` input is rolled along axis 1.
        * 3-D input whose first dimension is 3 is treated as channel-first
          ``(3, H, W)`` and rolled along axis 2.
        * Any other 3-D input is treated as channel-last ``(H, W, C)`` and
          rolled along axis 1.
        * Inputs of any other rank are rolled along axis 2 (NumPy raises if
          that axis does not exist).

    Args:
        data: NumPy array to shift.
        shift_x: Number of columns to shift by; converted with ``int()``.

    Returns:
        A new array with the same shape as ``data``.
    """
    shift = int(shift_x)

    if data.ndim == 2:
        # Roll the width axis directly. Routing 2-D data through the 3-D
        # layout heuristic would mistake an array with exactly three rows for
        # a channel-first image and leave it unshifted.
        return np.roll(data, shift, axis=1)

    if data.ndim == 3 and data.shape[0] != 3:
        # Channel-last layout: width is the middle axis.
        return np.roll(data, shift, axis=1)

    # Channel-first layout (or higher-rank input): width is axis 2.
    return np.roll(data, shift, axis=2)
def roll_normal(normal, shift_x):
    """Rotate per-pixel 3-vectors about their second component axis.

    The rotation angle is ``-2 * pi * shift_x / W`` where ``W`` is the width
    of the map, i.e. the yaw that corresponds to a horizontal shift of
    ``shift_x`` columns. Only the vector values are rotated; pixels are not
    moved.

    Layout handling:
        * 3-D input whose first dimension is 3 is treated as ``(3, H, W)``.
        * Any other 3-D input is treated as ``(H, W, 3)`` and returned in
          that layout.
        * 2-D input gets a trailing singleton axis added and removed again.
          NOTE(review): only ``(3, N)`` arrays survive this path (with
          ``W == 1``); other 2-D shapes fail in ``reshape`` -- confirm
          whether 2-D input is actually used by callers.

    Args:
        normal: NumPy array of vectors.
        shift_x: Horizontal shift, in pixels, that the rotation corresponds to.

    Returns:
        Array of rotated vectors in the same layout as the input. Floating
        input keeps its dtype; non-floating input yields a float64 result.
    """
    was_2d = normal.ndim == 2
    if was_2d:
        normal = normal[:, :, np.newaxis]

    channels_last = normal.ndim == 3 and normal.shape[0] != 3
    if channels_last:
        normal = np.moveaxis(normal, -1, 0)

    _, height, width = normal.shape
    angle = -2.0 * np.pi * (shift_x / float(width))
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    # Building the matrix with an integer dtype would truncate cos/sin to
    # 0 or +/-1 and silently produce a wrong rotation, so fall back to float64
    # whenever the input dtype is not floating/complex.
    if np.issubdtype(normal.dtype, np.inexact):
        rot_dtype = normal.dtype
    else:
        rot_dtype = np.float64

    rotation = np.array(
        [
            [cos_a, 0.0, -sin_a],
            [0.0, 1.0, 0.0],
            [sin_a, 0.0, cos_a],
        ],
        dtype=rot_dtype,
    )

    # Flatten the spatial dims so one matmul rotates every vector.
    rotated = (rotation @ normal.reshape(3, -1)).reshape(3, height, width)

    if channels_last:
        rotated = np.moveaxis(rotated, 0, -1)
    if was_2d:
        rotated = rotated[:, :, 0]
    return rotated
def compute_scale_and_shift(pred_g, targ_g, mask_g=None, eps=0.0, fit_shift=True):
    """Closed-form least-squares alignment of ``pred_g`` to ``targ_g``.

    Per sample, finds ``scale`` (and optionally ``shift``) minimising
    ``sum(mask * (scale * pred + shift - targ) ** 2)`` over dims 1-3.

    Input handling:
        * If ``pred_g.shape[0] == 6`` the three tensors are viewed as a single
          sample of shape ``(1, 6, H, W)`` using dims 2 and 3 of each tensor,
          so all six entries share one scale/shift.
        * If ``pred_g`` is 3-D with a leading dimension of 1, a batch
          dimension is prepended to all three tensors.
        * Otherwise the tensors are used as given and must be 4-D.

    Args:
        pred_g: Prediction tensor.
        targ_g: Target tensor, same shape as ``pred_g``.
        mask_g: Optional validity mask; defaults to all ones.
        eps: Added to the normal-equation determinant (``fit_shift=True``) or
            to the denominator (``fit_shift=False``).
        fit_shift: When False only the scale is fitted and the shift is zero.

    Returns:
        Tuple ``(scale, shift)`` of 1-D tensors with one entry per sample.
        Samples whose system is degenerate (determinant / denominator not
        strictly positive) get ``0`` for both values.
    """
    if mask_g is None:
        mask_g = torch.ones_like(pred_g, dtype=torch.bool)

    if pred_g.shape[0] == 6:
        pred_g = pred_g.view(1, 6, pred_g.shape[2], pred_g.shape[3])
        targ_g = targ_g.view(1, 6, targ_g.shape[2], targ_g.shape[3])
        mask_g = mask_g.view(1, 6, mask_g.shape[2], mask_g.shape[3])
    elif pred_g.shape[0] == 1 and pred_g.dim() == 3:
        pred_g = pred_g.unsqueeze(0)
        targ_g = targ_g.unsqueeze(0)
        mask_g = mask_g.unsqueeze(0)

    mask_g = mask_g.to(dtype=pred_g.dtype)
    reduce_dims = (1, 2, 3)

    # Entries of the 2x2 normal equations  [[a_00, a_01], [a_01, a_11]] x = b.
    a_00 = torch.sum(mask_g * pred_g * pred_g, dim=reduce_dims)
    a_01 = torch.sum(mask_g * pred_g, dim=reduce_dims)
    a_11 = torch.sum(mask_g, dim=reduce_dims)
    b_0 = torch.sum(mask_g * pred_g * targ_g, dim=reduce_dims)
    b_1 = torch.sum(mask_g * targ_g, dim=reduce_dims)

    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)

    if fit_shift:
        det = a_00 * a_11 - a_01 * a_01 + eps
        valid = det > 0
        scale[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
        shift[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
    else:
        denom = a_00 + eps
        # Guard the division like the fit_shift branch does: with the default
        # eps of 0 an all-zero (or fully masked) prediction would give 0/0.
        valid = denom > 0
        scale[valid] = b_0[valid] / denom[valid]

    return scale, shift
def compute_shift(pred, targ, mask, eps=1e-6):
    """Return the masked mean of ``targ - pred`` for each sample.

    If ``pred.shape[0] == 6`` the three tensors are first viewed as one sample
    of shape ``(1, 6, ...)`` (keeping their dims from index 2 onward), so a
    single offset is computed across all six entries. The reduction runs over
    dims 1-3, so the tensors must be 4-D at that point.

    Args:
        pred: Prediction tensor.
        targ: Target tensor, same shape as ``pred``.
        mask: Validity mask; converted to float weights.
        eps: Lower bound applied to the sum of weights before dividing.

    Returns:
        1-D tensor with one offset per sample.
    """
    if pred.shape[0] == 6:
        pred, targ, mask = (t.view(1, 6, *t.shape[2:]) for t in (pred, targ, mask))

    weights = mask.float()
    reduce_over = (1, 2, 3)
    weighted_residual = (weights * (targ - pred)).sum(dim=reduce_over)
    # Clamp so a fully masked sample yields 0 instead of dividing by zero.
    weight_total = weights.sum(dim=reduce_over).clamp_min(eps)
    return weighted_residual / weight_total
def get_positional_encoding(H, W, pixel_center=True, hw=96):
    """Build a 2-channel normalised coordinate grid and convert it to cube faces.

    Channel 0 holds the column coordinate and channel 1 the row coordinate,
    each mapped linearly to ``[-1, 1)`` (``index / size * 2 - 1``). The
    ``(2, H, W)`` float32 tensor is then passed to ``erp_to_cubemap``.

    Args:
        H: Grid height in pixels.
        W: Grid width in pixels.
        pixel_center: When True, sample at pixel centres (index + 0.5).
        hw: Forwarded to ``erp_to_cubemap`` as ``face_w``.

    Returns:
        The result of ``erp_to_cubemap`` for the coordinate grid.
    """
    offset = 0.5 if pixel_center else 0.0
    cols = np.arange(W, dtype=np.float64) + offset
    rows = np.arange(H, dtype=np.float64) + offset

    u_axis = (cols / W) * 2.0 - 1.0
    v_axis = (rows / H) * 2.0 - 1.0

    # 'xy' indexing yields (H, W) grids; stack them as the last axis.
    grid = np.stack(np.meshgrid(u_axis, v_axis, indexing="xy"), axis=-1)
    erp_tensor = torch.from_numpy(grid).permute(2, 0, 1).float()
    return erp_to_cubemap(erp_tensor, face_w=hw)
def unit_normals(n, eps=1e-6):
    """Scale vectors stored along dim ``-3`` to unit length.

    Args:
        n: Tensor with at least 3 dims and size 3 along dim ``-3``.
        eps: Lower bound on the vector length used as divisor, so that
            zero-length vectors map to zero instead of NaN.

    Returns:
        Tensor of the same shape as ``n``.

    Raises:
        AssertionError: If ``n`` does not have 3 channels at dim ``-3``.
    """
    assert n.dim() >= 3 and n.size(-3) == 3, "normals must have channel=3 at dim -3"
    length = torch.linalg.norm(n, dim=-3, keepdim=True)
    safe_length = length.clamp(min=eps)
    return n / safe_length
def _erp_dirs(H, W, device=None, dtype=None):
    """Return unit direction vectors for every pixel centre of an H x W grid.

    Columns map to an angle ``theta`` in ``(-pi, pi)`` and rows to an angle
    ``phi`` running from just below ``+pi/2`` (row 0) to just above ``-pi/2``
    (last row). The components are::

        x =  cos(phi) * cos(theta)
        y =  sin(phi)
        z = -cos(phi) * sin(theta)

    Args:
        H: Number of rows.
        W: Number of columns.
        device: Passed to ``torch.arange``.
        dtype: Passed to ``torch.arange``.

    Returns:
        Tensor of shape ``(3, H, W)``.
    """
    col_centers = (torch.arange(W, device=device, dtype=dtype) + 0.5) / W
    row_centers = (torch.arange(H, device=device, dtype=dtype) + 0.5) / H

    theta = (col_centers * (2.0 * torch.pi) - torch.pi).view(1, W).expand(H, W)
    phi = ((0.5 - row_centers) * torch.pi).view(H, 1).expand(H, W)

    cos_phi = torch.cos(phi)
    components = (
        cos_phi * torch.cos(theta),
        torch.sin(phi),
        -cos_phi * torch.sin(theta),
    )
    return torch.stack(components, dim=0)
def depth_to_normals_erp(depth, eps=1e-6):
    """Estimate unit surface normals from a depth map on an H x W angular grid.

    Each pixel is lifted to a 3-D point ``depth * direction`` using the
    directions from ``_erp_dirs``. Finite differences of those points along
    the width (wrap-around) and height (edge-replicated) axes are crossed and
    normalised with ``unit_normals``.

    Args:
        depth: Tensor of shape ``(1, H, W)`` or ``(B, 1, H, W)``.
        eps: Forwarded to ``unit_normals`` as the minimum vector length.

    Returns:
        Tensor of shape ``(3, H, W)`` for 3-D input, or ``(B, 3, H, W)`` for
        4-D input.

    Raises:
        AssertionError: If ``depth`` is not 3-D/4-D with a single channel at
            dim ``-3``.
    """
    assert depth.dim() in (3, 4) and depth.size(-3) == 1, "depth must be (1,H,W) or (B,1,H,W)"
    H, W = depth.shape[-2:]
    dirs = _erp_dirs(H, W, device=depth.device, dtype=depth.dtype)

    # Broadcasting the single depth channel against the 3 direction channels
    # gives one 3-D point per pixel.
    P = depth * dirs

    dtheta = 2.0 * torch.pi / W
    dphi = torch.pi / H

    # Width axis wraps around, so a circular roll gives both neighbours.
    P_l = torch.roll(P, shifts=+1, dims=-1)
    P_r = torch.roll(P, shifts=-1, dims=-1)
    dP_dtheta = (P_r - P_l) / (2.0 * dtheta)

    # Height axis does not wrap: replicate the first/last row instead.
    P_u = torch.cat([P[..., :1, :], P[..., :-1, :]], dim=-2)
    P_d = torch.cat([P[..., 1:, :], P[..., -1:, :]], dim=-2)
    dP_dphi = (P_d - P_u) / (2.0 * dphi)

    n = torch.cross(dP_dtheta, dP_dphi, dim=-3)
    return unit_normals(n, eps=eps)
def compute_edge_mask(depth, abs_thresh=0.1, rel_thresh=0.1):
    """Return a mask of valid pixels that do not sit on a depth discontinuity.

    Two 4-connected neighbours form an edge when both are valid
    (``depth > 0``), their absolute difference exceeds ``abs_thresh`` and the
    difference relative to the smaller of the two exceeds ``rel_thresh``.
    Both pixels of such a pair are removed from the result.

    Args:
        depth: 2-D NumPy array; converted to float32.
        abs_thresh: Absolute depth-difference threshold.
        rel_thresh: Relative depth-difference threshold.

    Returns:
        Boolean array of the same shape as ``depth``; True where the pixel is
        valid and not part of any edge pair.
    """
    assert depth.ndim == 2
    depth = depth.astype(np.float32, copy=False)
    valid = depth > 0
    eps = 1e-6
    edge = np.zeros_like(valid, dtype=bool)

    everything = slice(None)
    head, tail = slice(None, -1), slice(1, None)
    neighbour_pairs = (
        ((everything, head), (everything, tail)),  # left / right neighbours
        ((head, everything), (tail, everything)),  # upper / lower neighbours
    )

    for near, far in neighbour_pairs:
        gap = np.abs(depth[near] - depth[far])
        ratio = gap / (np.minimum(depth[near], depth[far]) + eps)
        jump = (valid[near] & valid[far]) & (gap > abs_thresh) & (ratio > rel_thresh)
        # Flag both members of every discontinuous pair.
        edge[near] |= jump
        edge[far] |= jump

    return valid & (~edge)
def erp_to_pointcloud(rgb, depth, mask=None):
    """Lift an RGB + depth image pair on an H x W angular grid to 3-D points.

    Every pixel centre is assigned an angle ``theta`` in ``(-pi, pi)`` from
    its column and an angle ``phi`` in ``(-pi/2, pi/2)`` from its row (row 0
    is nearest ``+pi/2``). The point is ``depth`` times the direction::

        (cos(phi) * cos(theta), sin(phi), cos(phi) * sin(theta))

    NOTE(review): the z component has the opposite sign to the one used by
    ``_erp_dirs`` in this file -- confirm this is intentional.

    Args:
        rgb: ``(H, W, 3)`` array; values are clipped to ``[-1, 1]`` and mapped
            to ``[0, 255]`` uint8.
        depth: ``(H, W)`` array; pixels with ``depth <= 0`` are dropped.
        mask: Optional array combined (logical AND) with ``depth > 0``.

    Returns:
        Tuple ``(points, colors)``: float32 ``(N, 3)`` coordinates and uint8
        ``(N, 3)`` colours for the kept pixels, in row-major pixel order.

    Raises:
        AssertionError: If the input shapes do not match the expected layout.
    """
    assert rgb.ndim == 3 and rgb.shape[-1] == 3, "rgb must be (H, W, 3)"
    assert depth.ndim == 2 and depth.shape[:2] == rgb.shape[:2], "depth must be (H, W) and match rgb H,W"

    height, width, _ = rgb.shape
    depth = depth.astype(np.float32, copy=False)

    col_frac = (np.arange(width, dtype=np.float32) + 0.5) / width
    row_frac = (np.arange(height, dtype=np.float32) + 0.5) / height
    theta, phi = np.meshgrid(
        col_frac * (2.0 * np.pi) - np.pi,
        (1 - row_frac) * np.pi - (np.pi / 2.0),
        indexing="xy",
    )

    cos_phi = np.cos(phi)
    xyz = np.stack(
        [
            depth * (cos_phi * np.cos(theta)),
            depth * np.sin(phi),
            depth * (cos_phi * np.sin(theta)),
        ],
        axis=-1,
    )

    keep = depth > 0
    if mask is not None:
        keep = (mask.astype(bool)) & keep

    points = xyz[keep]

    # Map colours from [-1, 1] to [0, 255]; the uint8 cast truncates.
    colors = ((np.clip(rgb, -1.0, 1.0) * 0.5 + 0.5) * 255.0).astype(np.uint8)
    colors = colors.reshape(height, width, 3)[keep]
    return points.astype(np.float32, copy=False), colors
def erp_to_point_cloud_glb(rgb, depth, mask=None, export_path=None):
    """Build a trimesh scene holding the coloured point cloud of an image.

    The points and colours come from ``erp_to_pointcloud(rgb, depth, mask)``
    and are added to a fresh ``trimesh.Scene`` as a ``trimesh.PointCloud``.

    Args:
        rgb: Passed to ``erp_to_pointcloud``.
        depth: Passed to ``erp_to_pointcloud``.
        mask: Passed to ``erp_to_pointcloud``.
        export_path: Destination handed to ``Scene.export``. When ``None``
            (the default) nothing is written and the scene is only returned.

    Returns:
        The constructed ``trimesh.Scene``.
    """
    points, colors = erp_to_pointcloud(rgb, depth, mask)

    scene = trimesh.Scene()
    scene.add_geometry(trimesh.PointCloud(vertices=points, colors=colors))

    # Exporting needs a destination to infer the file type from; skip it when
    # the caller did not supply one instead of forwarding None.
    if export_path is not None:
        scene.export(export_path)
    return scene