# SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics
# SPDX-License-Identifier: Apache-2.0
# See the LICENSE file in the project root for full license information.

import os
import json
import cv2
import torch
from PIL import Image
import imageio
import numpy as np
import open3d as o3d
from einops import rearrange


# Pattern used to fold every 4x4x4 block of a dense grid into 64 channels.
_BLOCK_MASK_PATTERN = 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z'


def _voxel_grid_indices(voxel_grid):
    """Return the integer grid indices of all voxels in *voxel_grid* as an (M, 3) array.

    The reshape keeps the result two-dimensional even when the grid is empty,
    so that column indexing downstream does not fail.
    """
    indices = np.array(
        [voxel.grid_index for voxel in voxel_grid.get_voxels()], dtype=np.int64
    )
    return indices.reshape(-1, 3)


def _occupancy_from_indices(indices, resolution, device=None):
    """Build a (1, R, R, R) long tensor with ones at the given voxel indices."""
    assert np.all(indices >= 0) and np.all(indices < resolution), "Some vertices are out of bounds"
    coords = torch.as_tensor(indices, dtype=torch.long, device=device)
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long, device=device)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    return ss


def _mean_color_grid(points_t, colors_t, resolution):
    """Average a per-point scalar over the voxels of a [-0.5, 0.5]^3 grid.

    points_t: (N, 3) float tensor; colors_t: (N,) float tensor on the same device.
    Points outside the cube are clamped into the border voxels.
    Returns a (1, R, R, R) float tensor; voxels without points hold 0.
    """
    coords = torch.floor((points_t + 0.5) * resolution).to(torch.long)
    coords = torch.clamp(coords, 0, resolution - 1)
    ix, iy, iz = coords[:, 0], coords[:, 1], coords[:, 2]
    lin = ix * (resolution * resolution) + iy * resolution + iz  # linear index in [0, R^3)

    n_cells = resolution * resolution * resolution
    sum_color = torch.zeros(n_cells, dtype=torch.float32, device=points_t.device)
    sum_color.index_add_(0, lin, colors_t)
    count = torch.zeros(n_cells, dtype=torch.long, device=points_t.device)
    count.index_add_(0, lin, torch.ones_like(lin, dtype=torch.long))

    # Empty voxels divide by 1 and therefore stay 0.
    mean_color = sum_color / torch.clamp(count.to(torch.float32), min=1.0)
    return mean_color.view(resolution, resolution, resolution, 1).permute(3, 0, 1, 2).contiguous()


def _to_block_mask(grid):
    """Fold each 4x4x4 block of a (c, R, R, R) grid into channels and cast to float."""
    return rearrange(grid, _BLOCK_MASK_PATTERN, n1=4, n2=4, n3=4).float()


def voxelize_mesh(points, faces, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a triangle mesh inside the cube [-0.5, 0.5]^3.

    Args:
        points: (N, 3) vertex positions accepted by ``o3d.utility.Vector3dVector``.
        faces: triangle indices, either an ``o3d.utility.Vector3iVector`` or
            something convertible to one.
        clip_range_first: clip the vertices into the cube before voxelizing.
        return_mask: also return the 4x4x4 block-folded float mask.
        resolution: number of voxels per axis (must be divisible by 4 when
            ``return_mask`` is set, because of the block folding).

    Returns:
        ``ss`` of shape (1, R, R, R), dtype long, or ``(ss, ss_mask)`` when
        ``return_mask`` is true.
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(points)
    # ``o3d.utility`` resolves to the active (CPU or CUDA) binding, so this
    # also works on Open3D builds that ship without ``o3d.cuda``.
    if isinstance(faces, o3d.utility.Vector3iVector):
        mesh.triangles = faces
    else:
        mesh.triangles = o3d.utility.Vector3iVector(faces)

    # The grid follows ``resolution`` (it used to be pinned to 64 cells).
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh,
        voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5),
        max_bound=(0.5, 0.5, 0.5),
    )
    ss = _occupancy_from_indices(_voxel_grid_indices(voxel_grid), resolution)

    if return_mask:
        return ss, _to_block_mask(ss)
    return ss


def transform_vertices(vertices, ops, params):
    """Apply a sequence of 'scale' / 'translation' operations to *vertices*.

    ``ops`` and ``params`` are consumed pairwise and in order. Any other
    operation name raises ``NotImplementedError``.
    """
    for op, param in zip(ops, params):
        if op == 'scale':
            vertices = vertices * param
        elif op == 'translation':
            vertices = vertices + param
        else:
            raise NotImplementedError
    return vertices


def normalize_vertices(vertices, scale_factor=1.0):
    """Center the bounding box of *vertices* at the origin and rescale it.

    The longest bounding-box edge is scaled to ``1 / scale_factor``
    (scale_factor 1 -> [-0.5, 0.5], 2 -> [-0.25, 0.25]).

    Returns:
        ``(vertices, trans_pos, scale_pos)`` where ``trans_pos`` is the (1, 3)
        bounding-box center that was subtracted and ``scale_pos`` the divisor
        (before the 1e-6 stabilizer is added).
    """
    min_pos, max_pos = np.min(vertices, axis=0), np.max(vertices, axis=0)
    trans_pos = (min_pos + max_pos)[None] / 2.0
    scale_pos = np.max(max_pos - min_pos) * scale_factor
    vertices = transform_vertices(
        vertices,
        ops=['translation', 'scale'],
        params=[-trans_pos, 1.0 / (scale_pos + 1e-6)],
    )
    return vertices, trans_pos, scale_pos


def renormalize_vertices(vertices, val_range=0.5, scale_factor=1.25):
    """Re-center and rescale *vertices* only if they leave [-val_range, val_range].

    When any coordinate is out of range, the same bounding-box normalization
    as ``normalize_vertices`` is applied with ``scale_factor``; otherwise the
    input is returned unchanged.
    """
    min_pos, max_pos = np.min(vertices, axis=0), np.max(vertices, axis=0)
    if (min_pos < -val_range).any() or (max_pos > val_range).any():
        trans_pos = (min_pos + max_pos)[None] / 2.0
        scale_pos = np.max(max_pos - min_pos) * scale_factor
        vertices = transform_vertices(
            vertices,
            ops=['translation', 'scale'],
            params=[-trans_pos, 1.0 / (scale_pos + 1e-6)],
        )
    return vertices


def rot_vertices(vertices, rot_angles, axis_list=('z',)):
    """Rotate *vertices* about the origin with Open3D, one axis at a time.

    Args:
        vertices: (N, 3) array accepted by ``o3d.utility.Vector3dVector``.
        rot_angles: iterable of angles, paired in order with ``axis_list``.
        axis_list: iterable of 'x' / 'y' / 'z'.

    Returns:
        (N, 3) numpy array of rotated points.

    Raises:
        NotImplementedError: for an axis name other than 'x', 'y' or 'z'.
    """
    axis_slot = {'x': 0, 'y': 1, 'z': 2}

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(vertices)
    for ang, axis in zip(rot_angles, axis_list):
        if axis not in axis_slot:
            raise NotImplementedError
        euler = [0, 0, 0]
        euler[axis_slot[axis]] = ang
        R = pcd.get_rotation_matrix_from_xyz(tuple(euler))
        pcd.rotate(R, center=(0., 0., 0.))
    return np.array(pcd.points)


def _rotmat_x(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the x axis for scalar tensor *a*."""
    ca, sa = torch.cos(a), torch.sin(a)
    one, zero = torch.ones_like(a), torch.zeros_like(a)
    return torch.stack([
        torch.stack([one, zero, zero]),
        torch.stack([zero, ca, -sa]),
        torch.stack([zero, sa, ca]),
    ])


def _rotmat_y(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the y axis for scalar tensor *a*."""
    ca, sa = torch.cos(a), torch.sin(a)
    one, zero = torch.ones_like(a), torch.zeros_like(a)
    return torch.stack([
        torch.stack([ca, zero, sa]),
        torch.stack([zero, one, zero]),
        torch.stack([-sa, zero, ca]),
    ])


def _rotmat_z(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the z axis for scalar tensor *a*."""
    ca, sa = torch.cos(a), torch.sin(a)
    one, zero = torch.ones_like(a), torch.zeros_like(a)
    return torch.stack([
        torch.stack([ca, -sa, zero]),
        torch.stack([sa, ca, zero]),
        torch.stack([zero, zero, one]),
    ])


def rot_vertices_torch(vertices, rot_angles, axis_list=('z',), center=(0.0, 0.0, 0.0)):
    """
    Torch counterpart of ``rot_vertices``.

    vertices:   (N,3) numpy or torch
    rot_angles: iterable of angles (radians), length matches axis_list
    axis_list:  iterable like ['x','y','z'] (applied in order)
    center:     rotation center, default origin (0,0,0)
    return:     torch.Tensor (N,3), same dtype/device as ``torch.as_tensor(vertices)``
    """
    v = torch.as_tensor(vertices)
    device, dtype = v.device, v.dtype
    c = torch.tensor(center, device=device, dtype=dtype).view(1, 3)
    builders = {'x': _rotmat_x, 'y': _rotmat_y, 'z': _rotmat_z}

    v = v - c  # translate to center
    # Rotations are applied in loop order; points are row vectors, hence R^T.
    for ang, axis in zip(rot_angles, axis_list):
        if axis not in builders:
            raise NotImplementedError(f"Unknown axis {axis}")
        a = torch.as_tensor(ang, device=device, dtype=dtype)
        v = v @ builders[axis](a).T
    return v + c


def get_instance_mask(instance_mask_path):
    """Read a 16-bit instance-index image and decode it to instance ids.

    The stored value is mapped from [0, 65535] to [0, 100] and rounded, i.e.
    the encoding is hand-coded for at most 100 objects.

    Returns:
        ``(index_mask, instance_list)``: the float32 id map and the unique ids
        as uint8.
    """
    index_mask = imageio.v3.imread(instance_mask_path)
    index_mask = np.rint(index_mask.astype(np.float32) / 65535 * 100.0)
    instance_list = np.unique(index_mask).astype(np.uint8)
    return index_mask, instance_list


def get_gt_depth(gt_depth_path, metadata):
    """Read a 16-bit depth image and de-normalize it with the metadata range.

    ``metadata['depth']['min']`` / ``['max']`` give the range that the stored
    [0, 65535] values map to. Returns a float32 tensor.
    """
    gt_depth = imageio.v3.imread(gt_depth_path).astype(np.float32) / 65535.
    depth_min, depth_max = metadata['depth']['min'], metadata['depth']['max']
    gt_depth = gt_depth * (depth_max - depth_min) + depth_min
    return torch.from_numpy(gt_depth).to(dtype=torch.float32)


def _zero_invalid_depth(est_depth, est_depth_mask):
    """Zero NaN/Inf entries of a depth tensor and remove them from the mask.

    est_depth: torch tensor; est_depth_mask: numpy boolean array of the same shape.
    """
    invalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
    est_depth_mask = np.logical_and(est_depth_mask, ~invalid_mask.detach().cpu().numpy())
    est_depth = torch.where(invalid_mask, 0.0, est_depth)
    return est_depth, est_depth_mask


def get_est_depth(est_depth_path):
    """Load an estimated depth map and its validity mask from an ``.npz`` file.

    The archive must contain ``depth`` and ``mask``. Non-finite depths are set
    to 0 and removed from the mask.

    Returns:
        ``(est_depth, est_depth_mask)``: float32 tensor and numpy bool array.
    """
    # NpzFile keeps the archive open; close it as soon as the arrays are read.
    with np.load(est_depth_path) as npz:
        est_depth = npz['depth']
        est_depth_mask = npz['mask']
    est_depth = torch.from_numpy(est_depth).to(dtype=torch.float32)
    return _zero_invalid_depth(est_depth, est_depth_mask)


def get_mix_est_depth(est_depth_path, image_size):
    """Load an estimated depth map, dispatching on the estimator named in the path.

    * ``'MoGe'``: archive holds ``depth`` and ``mask``.
    * ``'DAv2_'`` / ``'ml-depth-pro'``: archive holds ``depth``; the mask is the
      set of finite depths.
    * ``'VGGT_1B'``: archive holds ``depth`` and ``depth_conf``; the mask is
      ``depth_conf > 2.0`` on finite depths, and depth and mask are resized to
      ``image_size`` x ``image_size`` with nearest-neighbour sampling.

    ``image_size`` is only used by the VGGT branch.

    Returns:
        ``(est_depth, est_depth_mask)``, or ``None`` when the path matches no
        known estimator.
    """
    if 'MoGe' in est_depth_path:
        return get_est_depth(est_depth_path)

    if 'DAv2_' in est_depth_path or 'ml-depth-pro' in est_depth_path:
        with np.load(est_depth_path) as npz:
            est_depth = npz['depth']
        est_depth_mask = np.isfinite(est_depth)
        est_depth = torch.from_numpy(est_depth).to(dtype=torch.float32)
        return _zero_invalid_depth(est_depth, est_depth_mask)

    if 'VGGT_1B' in est_depth_path:
        with np.load(est_depth_path) as npz:
            est_depth = npz['depth']
            est_depth_mask = npz['depth_conf'] > 2.0
        valid_depth_mask = np.isfinite(est_depth)
        est_depth_mask = np.logical_and(est_depth_mask, valid_depth_mask)
        est_depth = np.where(valid_depth_mask, est_depth, 0.0)

        # Normalize to [0, 1) for the PIL round trip, then restore the range.
        depth_min, depth_max = np.min(est_depth), np.max(est_depth)
        est_depth = (est_depth - depth_min) / (depth_max - depth_min + 1e-6)
        est_depth = Image.fromarray(est_depth)
        est_depth = est_depth.resize((image_size, image_size), Image.Resampling.NEAREST)
        est_depth = torch.tensor(np.array(est_depth)).to(dtype=torch.float32)
        est_depth = est_depth * (depth_max - depth_min) + depth_min

        est_depth_mask = Image.fromarray(est_depth_mask.astype(np.float32))
        est_depth_mask = est_depth_mask.resize((image_size, image_size), Image.Resampling.NEAREST)
        est_depth_mask = np.array(est_depth_mask) > 0.5
        return _zero_invalid_depth(est_depth, est_depth_mask)

    # NOTE(review): unknown estimators fall through and yield None, exactly as
    # before; callers that unpack the result will fail. Consider raising.
    return None


def lstsq_align_depth(est_depth, gt_depth, mask):
    """Scale *est_depth* by the least-squares factor that best matches *gt_depth*.

    The single scale is fitted on the pixels where ``mask`` is non-zero
    (``mask`` is indexed as a 2-D tensor). With no valid pixel the scale is 1.
    """
    valid_coords = torch.nonzero(mask)
    if valid_coords.shape[0] > 0:
        rows, cols = valid_coords[:, 0], valid_coords[:, 1]
        valid_gt_depth = gt_depth[rows, cols]
        valid_est_depth = est_depth[rows, cols]
        X = torch.linalg.lstsq(
            valid_est_depth[None, :, None], valid_gt_depth[None, :, None]
        ).solution
        lstsq_scale = X.item()
    else:
        lstsq_scale = 1.0
    return est_depth * lstsq_scale


def get_cam_poses(frame_info, H, W):
    """Build pinhole intrinsics and the camera-to-world matrix for one frame.

    The focal length is derived from ``frame_info['camera_angle_x']`` and the
    image width; the principal point is the image center.

    Returns:
        ``(K, c2w)`` as float32 tensors; ``K`` is 3x3 and ``c2w`` is
        ``frame_info['transform_matrix']``.
    """
    camera_angle_x = float(frame_info['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)
    K = np.array([
        [focal, 0, 0.5 * W],
        [0, focal, 0.5 * H],
        [0, 0, 1],
    ])
    K = torch.from_numpy(K).float()
    c2w = torch.from_numpy(np.array(frame_info['transform_matrix'])).float()
    return K, c2w


def edge_mask_morph_gradient(mask, kernel, iterations=1):
    """
    Extract the edge band of a binary mask as dilation minus erosion.

    mask:       HxW, bool/uint8
    kernel:     structuring element passed to cv2.dilate / cv2.erode;
                a larger kernel gives a thicker edge
    iterations: number of dilate/erode iterations
    return:     edge_mask uint8 {0,1}
    """
    m = (mask.astype(np.uint8) > 0).astype(np.uint8)
    dil = cv2.dilate(m, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)
    ero = cv2.erode(m, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)
    edge = dil - ero
    return (edge > 0).astype(np.uint8)


def _alpha_foreground(image, opaque_if_missing=False):
    """Return the boolean ``alpha > 0`` map of a PIL image.

    With ``opaque_if_missing`` an image without an "A" band is treated as fully
    opaque; otherwise PIL's ``ValueError`` propagates.
    """
    if not opaque_if_missing:
        return np.asarray(image.getchannel("A")) > 0
    try:
        return np.asarray(image.getchannel("A")) > 0
    except ValueError:
        return np.ones_like(np.array(image.getchannel(0))) > 0


def _square_crop_bbox(mask, size_ratio=1.2):
    """Compute an enlarged square crop box around the non-zero region of *mask*.

    The square is centered on the tight bounding box, with half-size equal to
    half of the longer bounding-box edge times ``size_ratio``. The box may
    extend beyond the image.

    Returns:
        ``(aug_bbox, nonzero)``: ``[left, upper, right, lower]`` as Python ints
        and the ``np.nonzero`` tuple of the mask (rows first, then columns).

    Raises:
        ValueError: if the mask has no non-zero pixel.
    """
    nonzero = np.nonzero(mask)
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        raise ValueError("Cannot compute a crop box: the mask has no non-zero pixel")

    bbox = [cols.min(), rows.min(), cols.max(), rows.max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    aug_hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 * size_ratio
    aug_bbox = [
        int(center[0] - aug_hsize),
        int(center[1] - aug_hsize),
        int(center[0] + aug_hsize),
        int(center[1] + aug_hsize),
    ]
    return aug_bbox, nonzero


def _crop_pixel_grid(aug_bbox, image_size):
    """Sample ``image_size`` x ``image_size`` pixel coordinates inside *aug_bbox*.

    Returns ``(i, j)`` where ``i`` holds x (column) and ``j`` holds y (row)
    coordinates. 'ij' indexing is what ``torch.meshgrid`` used implicitly, so
    the first grid axis runs over x.
    NOTE(review): the resulting grids are laid out [x, y], not [row, col] —
    confirm this is the layout consumers of the rays expect.
    """
    xs = torch.linspace(aug_bbox[0], aug_bbox[2] - 1, steps=image_size)
    ys = torch.linspace(aug_bbox[1], aug_bbox[3] - 1, steps=image_size)
    return torch.meshgrid(xs, ys, indexing='ij')


def _rgb_to_tensor(image_rgb):
    """Convert an RGB PIL image to a (3, H, W) float tensor in [0, 1]."""
    img_np = np.asarray(image_rgb, dtype=np.uint8)
    return torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0


def process_scene_image(image: Image.Image, instance_mask: np.ndarray, image_size: int, resize_perturb: bool = False, resize_perturb_ratio: float = 0.0):
    """Resize a scene image and its foreground alpha to a square tensor pair.

    The alpha is ``(image alpha > 0) AND instance_mask``; an image without an
    alpha band counts as fully opaque. With ``resize_perturb``, and with
    probability ``resize_perturb_ratio``, image and alpha are down-sampled to a
    random side in [32, image_size) and up-sampled again (this requires
    ``image_size > 32``).

    Returns:
        ``(img_t, img4)``: the (3, S, S) RGB tensor and the (4, S, S) RGBA
        tensor, both in [0, 1].
    """
    size = (image_size, image_size)
    alpha = _alpha_foreground(image, opaque_if_missing=True)
    alpha = np.logical_and(alpha, instance_mask).astype(np.uint8) * 255

    image_resized = image.resize(size, Image.Resampling.LANCZOS).convert("RGB")
    alpha_resized = Image.fromarray(alpha, mode="L").resize(size, Image.Resampling.NEAREST)

    if resize_perturb and np.random.rand() < resize_perturb_ratio:
        rand_reso = np.random.randint(32, image_size)
        low = (rand_reso, rand_reso)
        image_resized = image_resized.resize(low, Image.Resampling.LANCZOS)
        image_resized = image_resized.resize(size, Image.Resampling.LANCZOS)
        alpha_resized = alpha_resized.resize(low, Image.Resampling.NEAREST)
        alpha_resized = alpha_resized.resize(size, Image.Resampling.NEAREST)

    img_np = np.array(image_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0
    a_np = np.array(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0

    img4 = torch.cat([img_t, a_t], dim=0)  # (4,S,S)
    return img_t, img4


def get_rays(i, j, K, c2w):
    """Compute world-space ray origins and directions for pixel coordinates.

    ``i`` are x (column) and ``j`` are y (row) coordinates of any common
    shape; 0.5 is added to hit pixel centers. The camera looks along -z with
    y pointing up.

    Returns:
        ``(rays_o, rays_d)``, each of shape ``i.shape + (3,)``; directions are
        not normalized.
    """
    i = i.float() + 0.5
    j = j.float() + 0.5
    dirs = torch.stack(
        [(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1
    )
    # Rotate ray directions from camera frame to the world frame.
    # Dot product, equals to: [c2w.dot(dir) for dir in dirs]
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def get_rays_fast(u: torch.Tensor, v: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor):
    """
    Matrix-product variant of ``get_rays`` for flat pixel lists.

    u, v: 1D tensor (pixel coords), dtype long/int64 or int32
    K: (3,3) or (4,4) but used as 3x3; on same device as output
    c2w: (4,4) or (3,4), uses [:3,:3] and [:3,3]
    return:
      rays_o: (N,3)
      rays_d: (N,3)
    """
    # Cast to float and add 0.5 to address pixel centers.
    u = u.to(dtype=torch.float32) + 0.5
    v = v.to(dtype=torch.float32) + 0.5

    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Directions in the camera frame, (N,3).
    dirs = torch.stack([(u - cx) / fx, -(v - cy) / fy, -torch.ones_like(u)], dim=-1)

    # Rotate into world coordinates: dirs @ R^T.
    R = c2w[:3, :3]
    rays_d = dirs @ R.T

    # Origin: the camera center (3,), expanded to (N,3).
    t = c2w[:3, 3]
    rays_o = t.expand_as(rays_d)
    return rays_o, rays_d


def process_instance_image(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray, depth_map: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor, image_size: int):
    """Crop one instance out of a scene image and gather per-pixel rays for it.

    The foreground is ``(image alpha > 0) AND instance_mask`` (an image
    without alpha counts as opaque). Rays, colors and depths are taken at the
    full-resolution foreground pixels; image and alpha are cropped to the
    enlarged square box and resized to ``image_size``.

    Returns:
        ``(img_t, a_t, rays_o, rays_d, rays_color, rays_t)``.

    Raises:
        ValueError: if the foreground is empty.
    """
    size = (image_size, image_size)
    alpha = _alpha_foreground(image, opaque_if_missing=True)
    alpha = np.logical_and(alpha, instance_mask).astype(np.uint8) * 255
    aug_bbox, valid_mask = _square_crop_bbox(alpha)
    rows, cols = valid_mask[0], valid_mask[1]

    i, j = torch.from_numpy(cols), torch.from_numpy(rows)
    rays_o, rays_d = get_rays(i, j, K, c2w)
    rays_color = color_mask[rows, cols].astype(np.float32)
    rays_t = depth_map[rows, cols]

    image_resized = image.crop(aug_bbox).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    alpha_resized = Image.fromarray(alpha, mode="L").crop(aug_bbox).resize(size, Image.Resampling.NEAREST)

    img_t = _rgb_to_tensor(image_resized)
    a_np = np.asarray(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0
    return img_t, a_t, rays_o, rays_d, rays_color, rays_t


def get_crop_area_rays(image: Image.Image, instance_mask: np.ndarray, K: torch.Tensor, c2w: torch.Tensor, image_size):
    """Generate an ``image_size`` x ``image_size`` ray grid over an instance crop.

    The crop is the enlarged square box around ``image alpha > 0``, optionally
    intersected with ``instance_mask`` (skipped when it is ``None``). The image
    must have an alpha band.

    Returns:
        ``(rays_o, rays_d)``, each (image_size, image_size, 3).

    Raises:
        ValueError: if the image has no alpha band or the foreground is empty.
    """
    alpha = _alpha_foreground(image)
    if instance_mask is not None:
        alpha = np.logical_and(alpha, instance_mask)
    aug_bbox, _ = _square_crop_bbox(alpha.astype(np.float32))

    i, j = _crop_pixel_grid(aug_bbox, image_size)
    rays_o, rays_d = get_rays(i, j, K, c2w)
    return rays_o, rays_d


def process_instance_image_crop(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray, depth_map: torch.Tensor, gt_depth_map: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor, image_size: int, edge_mask_morph_gradient_fn):
    """Crop one instance and resample image, masks, depths and rays to a square.

    Everything is cropped to the enlarged square box around the instance
    foreground and resized to ``image_size``. ``edge_mask_morph_gradient_fn``
    is called with the binary uint8 alpha only and must return its edge band;
    the foreground mask excludes that band and ``rays_color`` weights
    foreground pixels with 1 and edge pixels with 0.5.

    Returns:
        ``(img_t, a_t, fg_mask, rays_o, rays_d, rays_color, rays_t, valid_mask,
        depth_map_resized, gt_depth_map_resized, color_mask_resized)`` where
        the last three are mode-"F" PIL images and ``valid_mask`` is
        ``fg_mask.nonzero()``.

    Raises:
        ValueError: if the image has no alpha band or the foreground is empty.
    """
    size = (image_size, image_size)
    alpha = np.logical_and(_alpha_foreground(image), instance_mask).astype(np.float32)
    aug_bbox, _ = _square_crop_bbox(alpha)

    i, j = _crop_pixel_grid(aug_bbox, image_size)
    rays_o, rays_d = get_rays(i, j, K, c2w)

    def crop_float(array):
        # Nearest-neighbour keeps mask/depth values unblended.
        return Image.fromarray(array, mode="F").crop(aug_bbox).resize(size, Image.Resampling.NEAREST)

    image_resized = image.crop(aug_bbox).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    alpha_resized = crop_float(alpha)
    depth_map_resized = crop_float(depth_map.detach().cpu().numpy())
    gt_depth_map_resized = crop_float(gt_depth_map.detach().cpu().numpy())
    color_mask_resized = crop_float(color_mask.astype(np.float32))

    img_t = _rgb_to_tensor(image_resized)

    a_np = np.asarray(alpha_resized, dtype=np.float32).astype(dtype=np.uint8)
    edge_mask = edge_mask_morph_gradient_fn((a_np > 0).astype(np.uint8))
    fg_mask = (a_np > edge_mask).astype(np.uint8)
    rays_color = fg_mask.astype(np.float32) + edge_mask.astype(np.float32) * 0.5
    valid_mask = fg_mask.nonzero()
    rays_t = torch.from_numpy(np.asarray(depth_map_resized).astype(np.float32))

    a_t = torch.from_numpy(a_np).unsqueeze(0).float()  # already 0/1, no /255
    return img_t, a_t, fg_mask, rays_o, rays_d, rays_color, rays_t, valid_mask, depth_map_resized, gt_depth_map_resized, color_mask_resized


def process_instance_image_only(image: Image.Image, instance_mask: np.ndarray, image_size: int):
    """Crop one instance out of a scene image and resize it to a square.

    Returns:
        ``(img_t, a_t)``: the (3, S, S) RGB tensor and the (1, S, S) alpha
        tensor, both in [0, 1].

    Raises:
        ValueError: if the image has no alpha band or the foreground is empty.
    """
    size = (image_size, image_size)
    alpha = np.logical_and(_alpha_foreground(image), instance_mask).astype(np.uint8) * 255
    aug_bbox, _ = _square_crop_bbox(alpha)

    image_resized = image.crop(aug_bbox).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    alpha_resized = Image.fromarray(alpha, mode="L").crop(aug_bbox).resize(size, Image.Resampling.NEAREST)

    img_t = _rgb_to_tensor(image_resized)
    a_np = np.asarray(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0
    return img_t, a_t


def crop_depth_image(depth_image, aug_bbox, image_size):
    """Crop a depth tensor to *aug_bbox* and resize it with nearest sampling.

    Returns a float32 CPU tensor of shape (image_size, image_size).
    """
    d_np = depth_image.cpu().numpy().astype(np.float32)
    img = Image.fromarray(d_np, mode="F").crop(aug_bbox)
    img = img.resize((image_size, image_size), Image.Resampling.NEAREST)
    return torch.from_numpy(np.asarray(img, dtype=np.float32))


def proj_depth2pcd(mask, depth, image, rays_o, rays_d):
    """Back-project masked pixels to 3-D points along precomputed rays.

    ``mask`` selects pixels; ``depth``, ``rays_o`` and ``rays_d`` are indexed
    with the mask's two coordinates, and ``image`` is (C, H, W).

    Returns:
        ``(points, colors)`` as numpy arrays.
    """
    coords = torch.nonzero(mask)
    idx = [coords[:, 0].detach().cpu().numpy(), coords[:, 1].detach().cpu().numpy()]
    pixel_depth = depth[idx[0], idx[1]]
    pixel_color = image.detach().permute(1, 2, 0)[idx[0], idx[1]]
    pixel_points = rays_o[idx[0], idx[1]] + rays_d[idx[0], idx[1]] * pixel_depth[:, None]
    return pixel_points.detach().cpu().numpy(), pixel_color.detach().cpu().numpy()


def vox2pts(ss, resolution=64):
    """Return the centers of occupied voxels of ``ss[0]`` in [-0.5, 0.5]^3 as numpy."""
    coords = torch.nonzero(ss[0] > 0, as_tuple=False)
    position = (coords.float() + 0.5) / resolution - 0.5
    return position.detach().cpu().numpy()


def voxelize_pcd(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a numpy point cloud inside the cube [-0.5, 0.5]^3.

    Args:
        points: (N, 3) numpy array.
        points_color: optional (N,) numpy array of per-point scalars; when
            given, the returned mask is built from the per-voxel mean of these
            values instead of the occupancy.
        clip_range_first: clip the points into the cube before voxelizing.
        return_mask: also return the 4x4x4 block-folded float mask.
        resolution: voxels per axis (divisible by 4 when ``return_mask``).

    Returns:
        ``ss`` of shape (1, R, R, R), dtype long, or ``(ss, ss_mask)``.
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
        pcd,
        voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5),
        max_bound=(0.5, 0.5, 0.5),
    )
    ss = _occupancy_from_indices(_voxel_grid_indices(voxel_grid), resolution)

    if not return_mask:
        return ss
    if points_color is None:
        return ss, _to_block_mask(ss)

    points_t = torch.from_numpy(points).to(torch.float32)
    colors_t = torch.from_numpy(points_color).to(torch.float32)
    color_mean = _mean_color_grid(points_t, colors_t, resolution)
    return ss, _to_block_mask(color_mean)


def voxelize_pcd_pt(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Torch-tensor variant of ``voxelize_pcd``.

    NaNs/Infs in ``points`` (and in ``points_color`` when it is a tensor) are
    replaced via ``torch.nan_to_num``. Outputs live on ``points.device``; the
    Open3D voxelization itself runs on a CPU copy.

    Returns:
        ``ss`` of shape (1, R, R, R), dtype long, or ``(ss, ss_mask)``.
    """
    points = torch.nan_to_num(points)
    if isinstance(points_color, torch.Tensor):
        points_color = torch.nan_to_num(points_color)
    device = points.device

    if clip_range_first:
        points = torch.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.detach().cpu().numpy())
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
        pcd,
        voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5),
        max_bound=(0.5, 0.5, 0.5),
    )
    ss = _occupancy_from_indices(_voxel_grid_indices(voxel_grid), resolution, device=device)

    if not return_mask:
        return ss
    if points_color is None:
        return ss, _to_block_mask(ss)

    points_t = points.to(torch.float32)
    colors_t = torch.as_tensor(points_color, dtype=torch.float32, device=device)
    color_mean = _mean_color_grid(points_t, colors_t, resolution)
    return ss, _to_block_mask(color_mean)


def get_std_cond(root, instance, crop_size, return_mask=False):
    """Load a random rendered view of *instance* as an alpha-premultiplied crop.

    ``transforms.json`` is looked up under ``root/renders_cond/instance`` and,
    if absent there, under ``root/renders/instance``. One frame is drawn with
    ``np.random.randint``; the image is cropped to the enlarged square box
    around its non-zero alpha and resized to ``crop_size``.

    Returns:
        The (3, S, S) RGB tensor multiplied by alpha, plus the (1, S, S) alpha
        tensor when ``return_mask`` is true.

    Raises:
        ValueError: if the alpha channel is entirely zero.
    """
    image_root = os.path.join(root, 'renders_cond', instance)
    if not os.path.exists(os.path.join(image_root, 'transforms.json')):
        image_root = os.path.join(root, 'renders', instance)
    with open(os.path.join(image_root, 'transforms.json')) as f:
        metadata = json.load(f)

    n_views = len(metadata['frames'])
    view = np.random.randint(n_views)
    frame = metadata['frames'][view]

    image_path = os.path.join(image_root, frame['file_path'])
    # crop() loads the pixel data, so the file can be closed right afterwards.
    with Image.open(image_path) as source:
        alpha = np.array(source.getchannel(3))
        aug_bbox, _ = _square_crop_bbox(alpha)
        image = source.crop(aug_bbox)

    image = image.resize((crop_size, crop_size), Image.Resampling.LANCZOS)
    alpha = image.getchannel(3)
    image = image.convert('RGB')
    image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
    alpha = torch.tensor(np.array(alpha)).float() / 255.0
    image = image * alpha.unsqueeze(0)

    if return_mask:
        return image, alpha.unsqueeze(0)
    return image


def map_rotated_slat2canonical_pose(vertices, rot_slat_info):
    """Map vertices of a rotated, normalized shape back to its canonical pose.

    Applies ``rot_slat_info['scale']``, then ``['translation']``, then the
    negated ``['rotate']`` angles about z, y and x (in that order) around the
    origin.

    Returns:
        (N, 3) numpy array.
    """
    vertices_scale = rot_slat_info['scale']
    vertices_trans = np.array(rot_slat_info['translation'])
    rand_rot = rot_slat_info['rotate']

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(vertices * vertices_scale + vertices_trans)

    undo_x = pcd.get_rotation_matrix_from_xyz((-rand_rot[0], 0, 0))
    undo_y = pcd.get_rotation_matrix_from_xyz((0, -rand_rot[1], 0))
    undo_z = pcd.get_rotation_matrix_from_xyz((0, 0, -rand_rot[2]))
    for R in (undo_z, undo_y, undo_x):
        pcd.rotate(R, center=(0., 0., 0.))
    return np.asarray(pcd.points)


def project2ply(mask, depth, image, K, c2w):
    """Back-project masked pixels to 3-D points using camera intrinsics/pose.

    ``mask`` selects (row, col) pixels; ``depth`` is indexed with them and
    ``image`` is (C, H, W).

    Returns:
        ``(pixel_points, pixel_color)`` as numpy arrays.
    """
    coords = torch.nonzero(mask)
    rays_o, rays_d = get_rays(coords[:, 1], coords[:, 0], K, c2w)
    idx = [coords[:, 0].detach().cpu().numpy(), coords[:, 1].detach().cpu().numpy()]
    pixel_depth = depth[idx[0], idx[1]]
    pixel_color = image.detach().permute(1, 2, 0).cpu().numpy()[idx[0], idx[1]]
    pixel_points = rays_o + rays_d * pixel_depth[:, None]
    pixel_points = pixel_points.detach().cpu().numpy()
    return pixel_points, pixel_color