| import time |
| import threading |
| import numpy as np |
| import open3d as o3d |
| import viser |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
| import matplotlib.pyplot as plt |
| import logging |
|
|
| |
# Only surface errors from the websocket server and asyncio loggers
# (presumably chatter from viser's transport layer).
for _noisy_logger_name in ("websockets.server", "asyncio"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.ERROR)
del _noisy_logger_name
| |
| from huggingface_hub import snapshot_download |
|
|
| |
| |
| |
# TCP port the Viser web server listens on.
SERVER_PORT: int = 7860
# Local directory the dataset snapshot is downloaded into; each non-hidden
# sub-directory is offered as a scene.
DATA_ROOT: Path = Path("./assets")
# Upper bound on the number of frame_*.ply files loaded per scene.
MAX_FRAMES: int = 400
# Hugging Face Hub dataset repository holding the demo scenes.
DATASET_REPO_ID: str = "cyun9286/Holi4d_demo"
|
|
| |
| |
| |
|
|
def remove_radius_outlier_open3d(points: np.ndarray, nb_neighbors: int = 30, std_ratio: float = 2.0) -> Tuple[np.ndarray, np.ndarray]:
    """Remove outliers from a point array using Open3D (CPU).

    Despite the function name, this applies Open3D's *statistical* outlier
    filter (``PointCloud.remove_statistical_outlier``), not the radius filter.

    Args:
        points: ``(N, 3)`` array of xyz coordinates.
        nb_neighbors: Number of neighbours used to compute each point's
            average neighbour distance.
        std_ratio: Standard-deviation ratio for the rejection threshold;
            smaller values filter more aggressively (per the Open3D docs).

    Returns:
        ``(filtered_points, kept_indices)``: the surviving points and their
        integer indices into ``points``. Empty input is returned unchanged
        together with an empty integer index array, so callers can always
        use ``kept_indices`` for indexing.
    """
    if points.shape[0] == 0:
        # Integer dtype to match the non-empty path (np.array([]) would be float64).
        return points, np.zeros(0, dtype=np.int64)

    pcd = o3d.geometry.PointCloud()
    # Vector3dVector expects float64 (N, 3) data.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))

    pcd_filtered, kept = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
    return np.asarray(pcd_filtered.points), np.asarray(kept, dtype=np.int64)
|
|
def compute_trajectory_colors(trajectories: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Assign each trajectory a rainbow colour ranked by where it first appears.

    Args:
        trajectories: ``(T, N, 3)`` positions of N tracked points over T frames.
        mask: ``(T, N)`` boolean visibility; ``mask[t, i]`` is True when point
            ``i`` is valid in frame ``t``.

    Returns:
        ``(N, 3)`` uint8 RGB colours. Each point is scored by the sum of its
        per-axis min-max-normalised first-visible position (never-visible
        points score 0). Points are ranked by score (ties keep index order),
        and rank ``k`` receives the ``k``-th of N evenly spaced samples of
        matplotlib's ``hsv`` colormap. Returns an empty ``(0, 3)`` array when
        ``N == 0``.
    """
    num_points = trajectories.shape[1]
    if num_points == 0:
        # nanmin/nanmax below have no identity for empty reductions.
        return np.zeros((0, 3), dtype=np.uint8)

    mask = np.asarray(mask, dtype=bool)
    ever_visible = np.any(mask, axis=0)
    scores = np.zeros(num_points)

    # Skipping this when nothing is visible also avoids np.argmax over an
    # empty frame axis (T == 0), which raises.
    if np.any(ever_visible):
        cols = np.flatnonzero(ever_visible)
        # argmax over a boolean column is the index of its first True.
        first_frame = np.argmax(mask, axis=0)[cols]
        first_xyz = trajectories[first_frame, cols]

        lo = np.nanmin(first_xyz, axis=0)
        hi = np.nanmax(first_xyz, axis=0)
        span = hi - lo
        span[span == 0] = 1.0  # constant axis: avoid division by zero
        scores[cols] = np.nansum((first_xyz - lo) / span, axis=1)

    order = np.argsort(scores, kind="stable")
    # NOTE: the hsv colormap is cyclic, so the lowest and highest ranks are both red.
    palette = plt.cm.hsv(np.linspace(0, 1, num_points))[:, :3]
    colors = np.zeros((num_points, 3))
    colors[order] = palette
    return (colors * 255).astype(np.uint8)
|
|
| |
| |
| |
class GlobalState:
    """Mutable container for the scene currently on display.

    Holds the viser node handles, the raw and downsampled trajectory arrays,
    and the trail history. ``lock`` is a ``threading.Lock`` used to
    coordinate access between threads.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.is_loaded, self.num_frames = False, 0

        # Scene-graph handles: one point-cloud node per frame plus one line-segment node.
        self.point_nodes: List[viser.SceneNodeHandle] = []
        self.line_node: Optional[viser.SceneNodeHandle] = None

        # Full-resolution trajectories and their visibility mask, as loaded.
        self.trajectory_raw_all = None
        self.visibility_mask_raw_all = None

        # Downsampled copies of the above plus per-point colours;
        # current_downsample == -1 means "not computed yet".
        self.current_downsample = -1
        self.trajectories_3d = self.visibility_mask = self.initial_colors = None

        # Trail segment batches and their colours, oldest first.
        self.history_lines_pos, self.history_lines_col = [], []


# Process-wide singleton shared by the loaders and the playback loop.
state = GlobalState()
|
|
| |
| |
| |
def download_data():
    """Fetch the demo dataset snapshot from the Hugging Face Hub into DATA_ROOT.

    Best effort: any exception is reported on stdout and swallowed, so
    startup continues even when the download fails.
    """
    print(f"--- 开始下载数据: {DATASET_REPO_ID} ---")
    # NOTE(review): recent huggingface_hub releases deprecate
    # `local_dir_use_symlinks` and `resume_download`; confirm against the
    # installed version before removing them.
    download_kwargs = dict(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        local_dir=DATA_ROOT,
        local_dir_use_symlinks=False,
        resume_download=True,
    )
    try:
        snapshot_download(**download_kwargs)
    except Exception as err:
        print(f"!!! 数据下载失败: {err}")
    else:
        print("--- 数据下载完成 ---")
|
|
def clear_scene(server: viser.ViserServer):
    """Remove the current scene's nodes and reset ``state`` to an empty, unloaded scene.

    ``server`` is not used; it is kept for call-site compatibility.
    Everything runs while holding ``state.lock``.
    """
    with state.lock:
        state.is_loaded = False

        for node in state.point_nodes:
            node.remove()
        state.point_nodes.clear()

        if state.line_node is not None:
            state.line_node.remove()
            state.line_node = None

        state.num_frames = 0
        state.history_lines_pos = []
        state.history_lines_col = []

        # Drop every trajectory-derived array, not only the raw one: the
        # downsampled arrays are strided views of the raw arrays and would
        # otherwise keep the previous scene's data alive.
        state.trajectory_raw_all = None
        state.visibility_mask_raw_all = None
        state.trajectories_3d = None
        state.visibility_mask = None
        state.initial_colors = None
        state.current_downsample = -1
|
|
def _sorted_frame_files(scene_path: Path) -> List[Path]:
    """Return the ``frame_<index>.ply`` files in ``scene_path`` ordered by integer index.

    Files whose text after the last ``_`` is not an integer are skipped
    rather than aborting the whole scene load.
    """
    indexed = []
    for ply_file in scene_path.glob("frame_*.ply"):
        try:
            frame_index = int(ply_file.stem.split("_")[-1])
        except ValueError:
            continue
        indexed.append((frame_index, ply_file))
    indexed.sort(key=lambda item: item[0])
    return [ply_file for _, ply_file in indexed]


def _load_trajectories(traj_path: Path, num_frames: int) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Load ``traj_path`` (if present) and derive its visibility mask.

    The array is truncated to its first ``num_frames`` entries along axis 0.
    A point counts as visible in a frame when none of its coordinates is NaN.
    Returns ``(None, None)`` when the file is missing or fails to load.
    """
    if not traj_path.exists():
        return None, None
    print("加载并清洗轨迹数据 (这可能需要一点时间)...")
    try:
        traj_raw = np.load(str(traj_path))[:num_frames]
        vis_mask_raw = ~np.isnan(traj_raw).any(axis=-1)
    except Exception as e:
        # Best effort: the scene is still shown, just without trajectory trails.
        print(f"轨迹加载失败: {e}")
        return None, None
    print("轨迹数据处理完成")
    return traj_raw, vis_mask_raw


def load_scene_data(server: viser.ViserServer, scene_name: str, gui_line_width_val: float):
    """Replace the displayed scene with the one stored in ``DATA_ROOT / scene_name``.

    Clears the current scene, then loads up to MAX_FRAMES ``frame_<index>.ply``
    point clouds (ordered by index) as hidden point-cloud nodes. If present,
    ``trajectory_all_pointmap.npy`` is also loaded, and an empty line-segment
    node is created with ``gui_line_width_val``. Everything is published into
    ``state`` under ``state.lock``.

    A "Loading ..." label is shown meanwhile and is always removed, even if
    loading raises. Nodes created before such a failure are removed as well.
    """
    scene_path = DATA_ROOT / scene_name
    print(f"正在加载场景: {scene_name} ...")

    loading_node = server.scene.add_label("loading", f"Loading {scene_name}...\nPlease wait.", position=(0,0,0))
    new_nodes = []
    line_node = None
    published = False
    try:
        clear_scene(server)

        ply_files = _sorted_frame_files(scene_path)[:MAX_FRAMES]
        if not ply_files:
            print("未找到 ply 文件")
            return

        traj_raw, vis_mask_raw = _load_trajectories(scene_path / 'trajectory_all_pointmap.npy', len(ply_files))

        print(f"加载 {len(ply_files)} 帧点云...")
        for i, ply_file in enumerate(ply_files):
            try:
                pcd = o3d.io.read_point_cloud(str(ply_file))
                # Scale by 100 and negate Y and Z for display.
                points = np.asarray(pcd.points) * 100 * np.array([1, -1, -1])
                colors = np.asarray(pcd.colors)
                new_nodes.append(server.scene.add_point_cloud(
                    name=f"/scene/frame_{i}",
                    points=points,
                    colors=(colors * 255).astype(np.uint8),
                    point_size=0.01,
                    point_shape="rounded",
                    visible=False,
                ))
            except Exception as e:
                # NOTE(review): a skipped frame shifts the list position of every
                # later node relative to its ply / trajectory frame index.
                print(f"帧 {i} 加载失败: {e}")

        if traj_raw is not None:
            line_node = server.scene.add_line_segments(
                name="/scene/trajectories",
                points=np.zeros((0, 2, 3)),
                colors=np.zeros((0, 2, 3), dtype=np.uint8),
                line_width=gui_line_width_val,
                visible=True,
            )

        with state.lock:
            state.point_nodes = new_nodes
            state.num_frames = len(new_nodes)
            state.trajectory_raw_all = traj_raw
            state.visibility_mask_raw_all = vis_mask_raw
            state.line_node = line_node
            state.current_downsample = -1  # force the playback loop to recompute
            state.is_loaded = True
        published = True
    finally:
        if not published:
            # Do not leave orphaned nodes in the scene after a failed load.
            for node in new_nodes:
                node.remove()
            if line_node is not None:
                line_node.remove()
        loading_node.remove()
    print(f"场景加载完毕。")
|
|
| |
| |
| |
def _resample_trajectories(step: int) -> None:
    """Rebuild the downsampled trajectory arrays for a new downsample ``step``.

    Keeps every ``step``-th tracked point of the raw arrays, recomputes their
    colours and discards the trail history. The caller must hold
    ``state.lock``, and ``state.trajectory_raw_all`` must not be None.
    """
    state.current_downsample = step
    state.trajectories_3d = state.trajectory_raw_all[:, ::step]
    state.visibility_mask = state.visibility_mask_raw_all[:, ::step]
    state.initial_colors = compute_trajectory_colors(state.trajectories_3d, state.visibility_mask)
    state.history_lines_pos = []
    state.history_lines_col = []


def _advance_trails(t_curr: int, continuous: bool, max_len: int) -> None:
    """Extend the trajectory trail to frame ``t_curr`` and push it to ``state.line_node``.

    Adds one batch of segments joining each tracked point's position at
    ``t_curr - 1`` to its position at ``t_curr``. Only points visible in both
    frames and moving less than 1 unit after scaling are included. At most
    ``max_len`` batches are kept before uploading.

    When ``continuous`` is False, or at frame 0, the existing trail does not
    connect to this frame and is discarded first. The caller must hold
    ``state.lock``, with ``state.line_node`` and the downsampled trajectory
    arrays populated.
    """
    if t_curr == 0 or not continuous:
        state.history_lines_pos = []
        state.history_lines_col = []

    traj = state.trajectories_3d
    vis = state.visibility_mask
    # Also guards trajectory files with fewer frames than loaded point clouds.
    if 0 < t_curr < len(traj):
        t_prev = t_curr - 1
        # NOTE(review): point clouds are displayed as xyz * 100 * [1, -1, -1],
        # but trajectories are only scaled by 100 here. Confirm the trajectory
        # file is already in the flipped frame.
        pos_prev = traj[t_prev] * 100
        pos_curr = traj[t_curr] * 100
        both_visible = vis[t_prev] & vis[t_curr]
        if np.any(both_visible):
            p1 = pos_prev[both_visible]
            p2 = pos_curr[both_visible]
            # Longer steps are treated as tracking jumps and not drawn.
            short = np.linalg.norm(p2 - p1, axis=1) < 1.0
            if np.any(short):
                cols = state.initial_colors[both_visible][short]
                state.history_lines_pos.append(np.stack([p1[short], p2[short]], axis=1))
                state.history_lines_col.append(np.stack([cols, cols], axis=1))

    while len(state.history_lines_pos) > max_len:
        state.history_lines_pos.pop(0)
        state.history_lines_col.pop(0)

    if state.history_lines_pos:
        seg_points = np.concatenate(state.history_lines_pos, axis=0)
        seg_colors = np.concatenate(state.history_lines_col, axis=0)
    else:
        seg_points = np.zeros((0, 2, 3))
        seg_colors = np.zeros((0, 2, 3), dtype=np.uint8)
    # Update colours together with points so their lengths stay in sync.
    state.line_node.points = seg_points
    state.line_node.colors = seg_colors


def main():
    """Download the demo data, start the Viser server and run the playback loop.

    Never returns. The loop advances the timestep while "Playing" is checked.
    It shows the current frame's point cloud and updates the trajectory trail.
    All reads and writes of ``state`` in the loop happen under ``state.lock``,
    because the scene-select callback replaces the scene from another thread.
    """
    download_data()

    print(f"启动 Viser 服务器,端口: {SERVER_PORT}...")
    server = viser.ViserServer(host="0.0.0.0", port=SERVER_PORT)

    if not DATA_ROOT.exists():
        server.scene.add_label("err", "Error: Data folder empty", position=(0,0,0))
        while True:
            time.sleep(1)

    scene_names = sorted(
        d.name for d in DATA_ROOT.iterdir()
        if d.is_dir() and not d.name.startswith(".")
    )

    if not scene_names:
        server.scene.add_label("err", "No scenes found", position=(0,0,0))
        while True:
            time.sleep(1)

    with server.gui.add_folder("Scene Control"):
        gui_scene_select = server.gui.add_dropdown("Select Scene", options=scene_names, initial_value=scene_names[0])

    with server.gui.add_folder("Playback Controls"):
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=24)
        gui_timestep = server.gui.add_slider("Timestep", min=0, max=100, step=1, initial_value=0)
        gui_point_size = server.gui.add_slider("Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01)
        gui_line_width = server.gui.add_slider("Line width", min=0.1, max=5.0, step=0.1, initial_value=0.5)
        gui_max_traj_length = server.gui.add_slider("Trail Length", min=1, max=50, step=1, initial_value=5)
        gui_downsample = server.gui.add_slider("Downsample", min=1, max=100, step=1, initial_value=100)
        gui_vis_mode = server.gui.add_button_group("Vis Mode", ("PointCloud", "Tracking", "Both"))
        gui_vis_mode.value = "Both"

    @gui_scene_select.on_update
    def _(_):
        load_scene_data(server, gui_scene_select.value, gui_line_width.value)
        if state.num_frames > 0:
            # Reset the value before shrinking the range so it is never out of bounds.
            gui_timestep.value = 0
            gui_timestep.max = max(1, state.num_frames - 1)

    load_scene_data(server, scene_names[0], gui_line_width.value)
    gui_timestep.max = max(1, state.num_frames - 1)

    prev_timestep = -1     # frame whose point cloud is shown (-1: none)
    prev_display = None    # display settings last applied
    trail_timestep = -1    # frame the trail was last advanced to (-1: trail reset)
    rendered_nodes = None  # state.point_nodes list the trackers above refer to

    while True:
        with state.lock:
            num_frames = state.num_frames if state.is_loaded else 0
        if num_frames == 0:
            time.sleep(0.1)
            continue

        if gui_playing.value:
            gui_timestep.value = (int(gui_timestep.value) + 1) % num_frames
        t_curr = int(gui_timestep.value)

        vis_mode = gui_vis_mode.value
        show_points = vis_mode in ("PointCloud", "Both")
        want_lines = vis_mode in ("Tracking", "Both")
        point_size = gui_point_size.value
        line_width = gui_line_width.value
        downsample = int(gui_downsample.value)
        max_len = int(gui_max_traj_length.value)

        with state.lock:
            # Re-check: a scene switch may have started since the check above.
            if state.is_loaded and state.num_frames > 0:
                nodes = state.point_nodes
                if nodes is not rendered_nodes:
                    # A different scene was loaded; none of its frames is shown yet.
                    rendered_nodes = nodes
                    prev_timestep, prev_display, trail_timestep = -1, None, -1

                if state.trajectory_raw_all is not None and downsample != state.current_downsample:
                    _resample_trajectories(downsample)
                    trail_timestep = -1  # trail restarts; the shown frame is unaffected

                show_lines = (
                    want_lines
                    and state.line_node is not None
                    and state.trajectories_3d is not None
                )
                display = (show_points, show_lines, point_size, line_width)

                # Apply visibility on frame changes and on GUI changes (even when paused).
                if t_curr != prev_timestep or display != prev_display:
                    if prev_timestep != t_curr and 0 <= prev_timestep < len(nodes):
                        nodes[prev_timestep].visible = False
                    if 0 <= t_curr < len(nodes):
                        nodes[t_curr].visible = show_points
                        if show_points:
                            nodes[t_curr].point_size = point_size
                    if show_lines:
                        state.line_node.visible = True
                        state.line_node.line_width = line_width
                    elif state.line_node is not None:
                        state.line_node.visible = False

                if show_lines:
                    if t_curr != trail_timestep:
                        _advance_trails(t_curr, t_curr == trail_timestep + 1, max_len)
                        trail_timestep = t_curr
                else:
                    # Start a fresh trail when lines are shown again.
                    trail_timestep = -1

                prev_timestep, prev_display = t_curr, display

        time.sleep(1.0 / gui_framerate.value)


if __name__ == "__main__":
    main()