Spaces:
Paused
Paused
| import os | |
| from typing import NamedTuple | |
| import numpy as np | |
| import json | |
| from plyfile import PlyData, PlyElement | |
| from utils.sh_utils import SH2RGB | |
| from scene.gaussian_model import BasicPointCloud | |
| import torch.nn.functional as F | |
| from imageio.v2 import imread | |
| import torch | |
| class CameraInfo(NamedTuple): | |
| K: np.array | |
| c2w: np.array | |
| image: np.array | |
| image_path: str | |
| image_name: str | |
| width: int | |
| height: int | |
| semantic2d: np.array | |
| optical_image: np.array | |
| depth: torch.tensor | |
| mask: np.array | |
| timestamp: int | |
| dynamics: dict | |
| class SceneInfo(NamedTuple): | |
| point_cloud: BasicPointCloud | |
| train_cameras: list | |
| test_cameras: list | |
| nerf_normalization: dict | |
| ply_path: str | |
| verts: dict | |
| def getNerfppNorm(cam_info, data_type): | |
| def get_center_and_diag(cam_centers): | |
| cam_centers = np.hstack(cam_centers) | |
| avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) | |
| center = avg_cam_center | |
| dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) | |
| diagonal = np.max(dist) | |
| return center.flatten(), diagonal | |
| cam_centers = [] | |
| for cam in cam_info: | |
| cam_centers.append(cam.c2w[:3, 3:4]) # cam_centers in world coordinate | |
| radius = 10 | |
| return {'radius': radius} | |
| def fetchPly(path): | |
| plydata = PlyData.read(path) | |
| vertices = plydata['vertex'] | |
| positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T | |
| if 'red' in vertices: | |
| colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 | |
| else: | |
| print('Create random colors') | |
| shs = np.ones((positions.shape[0], 3)) * 0.5 | |
| colors = SH2RGB(shs) | |
| normals = np.zeros((positions.shape[0], 3)) | |
| return BasicPointCloud(points=positions, colors=colors, normals=normals) | |
| def storePly(path, xyz, rgb): | |
| # Define the dtype for the structured array | |
| dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), | |
| ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), | |
| ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] | |
| normals = np.zeros_like(xyz) | |
| elements = np.empty(xyz.shape[0], dtype=dtype) | |
| attributes = np.concatenate((xyz, normals, rgb), axis=1) | |
| elements[:] = list(map(tuple, attributes)) | |
| # Create the PlyData object and write to file | |
| vertex_element = PlyElement.describe(elements, 'vertex') | |
| ply_data = PlyData([vertex_element]) | |
| ply_data.write(path) | |
| def readHUGSIMCameras(path, data_type, ignore_dynamic): | |
| train_cam_infos, test_cam_infos = [], [] | |
| with open(os.path.join(path, 'meta_data.json')) as json_file: | |
| meta_data = json.load(json_file) | |
| verts = {} | |
| if 'verts' in meta_data and not ignore_dynamic: | |
| verts_list = meta_data['verts'] | |
| for k, v in verts_list.items(): | |
| verts[k] = np.array(v) | |
| frames = meta_data['frames'] | |
| for idx, frame in enumerate(frames): | |
| c2w = np.array(frame['camtoworld']) | |
| rgb_path = os.path.join(path, frame['rgb_path'].replace('./', '')) | |
| rgb_split = rgb_path.split('/') | |
| image_name = '_'.join([rgb_split[-2], rgb_split[-1][:-4]]) | |
| image = imread(rgb_path) | |
| semantic_2d = None | |
| semantic_pth = rgb_path.replace("images", "semantics").replace('.png', '.npy').replace('.jpg', '.npy') | |
| if os.path.exists(semantic_pth): | |
| semantic_2d = np.load(semantic_pth) | |
| semantic_2d[(semantic_2d == 14) | (semantic_2d == 15)] = 13 | |
| optical_path = rgb_path.replace("images", "flow").replace('.png', '_flow.npy').replace('.jpg', '_flow.npy') | |
| if os.path.exists(optical_path): | |
| optical_image = np.load(optical_path) | |
| else: | |
| optical_image = None | |
| depth_path = rgb_path.replace("images", "depth").replace('.png', '.pt').replace('.jpg', '.pt') | |
| if os.path.exists(depth_path): | |
| depth = torch.load(depth_path, weights_only=True) | |
| else: | |
| depth = None | |
| mask = None | |
| mask_path = rgb_path.replace("images", "masks").replace('.png', '.npy').replace('.jpg', '.npy') | |
| if os.path.exists(mask_path): | |
| mask = np.load(mask_path) | |
| timestamp = frame.get('timestamp', -1) | |
| intrinsic = np.array(frame['intrinsics']) | |
| dynamics = {} | |
| if 'dynamics' in frame and not ignore_dynamic: | |
| dynamics_list = frame['dynamics'] | |
| for iid in dynamics_list.keys(): | |
| dynamics[iid] = torch.tensor(dynamics_list[iid]).cuda() | |
| cam_info = CameraInfo(K=intrinsic, c2w=c2w, image=np.array(image), | |
| image_path=rgb_path, image_name=image_name, height=image.shape[0], | |
| width=image.shape[1], semantic2d=semantic_2d, | |
| optical_image=optical_image, depth=depth, mask=mask, timestamp=timestamp, dynamics=dynamics) | |
| if data_type == 'kitti360': | |
| if idx < 20: | |
| train_cam_infos.append(cam_info) | |
| elif idx % 20 < 16: | |
| train_cam_infos.append(cam_info) | |
| elif idx % 20 >= 16: | |
| test_cam_infos.append(cam_info) | |
| else: | |
| continue | |
| elif data_type == 'kitti': | |
| if idx < 10 or idx >= len(frames) - 4: | |
| train_cam_infos.append(cam_info) | |
| elif idx % 4 < 2: | |
| train_cam_infos.append(cam_info) | |
| elif idx % 4 == 2: | |
| test_cam_infos.append(cam_info) | |
| else: | |
| continue | |
| elif data_type == "nuscenes": | |
| if idx % 30 >= 24: | |
| test_cam_infos.append(cam_info) | |
| else: | |
| train_cam_infos.append(cam_info) | |
| elif data_type == "waymo": | |
| if idx % 15 >= 12: | |
| test_cam_infos.append(cam_info) | |
| else: | |
| train_cam_infos.append(cam_info) | |
| elif data_type == "pandaset": | |
| if idx > 30 and idx % 30 >= 24: | |
| test_cam_infos.append(cam_info) | |
| else: | |
| train_cam_infos.append(cam_info) | |
| else: | |
| raise NotImplementedError | |
| return train_cam_infos, test_cam_infos, verts | |
| def readHUGSIMInfo(path, data_type, ignore_dynamic): | |
| train_cam_infos, test_cam_infos, verts = readHUGSIMCameras(path, data_type, ignore_dynamic) | |
| print(f'Loaded {len(train_cam_infos)} train cameras and {len(test_cam_infos)} test cameras') | |
| nerf_normalization = getNerfppNorm(train_cam_infos, data_type) | |
| ply_path = os.path.join(path, "points3d.ply") | |
| if not os.path.exists(ply_path): | |
| assert False, "Requires for initialize 3d points as inputs" | |
| try: | |
| pcd = fetchPly(ply_path) | |
| except Exception as e: | |
| print('When loading point clound, meet error:', e) | |
| exit(0) | |
| scene_info = SceneInfo(point_cloud=pcd, | |
| train_cameras=train_cam_infos, | |
| test_cameras=test_cam_infos, | |
| nerf_normalization=nerf_normalization, | |
| ply_path=ply_path, | |
| verts=verts) | |
| return scene_info | |
| sceneLoadTypeCallbacks = { | |
| "HUGSIM": readHUGSIMInfo, | |
| } |