# Holi4D / app.py
# leojiahlu — "Update app code" (9ce9d4a)
import time
import threading
import numpy as np
import open3d as o3d
import viser
from pathlib import Path
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import logging
# Quiet noisy third-party loggers (e.g. websockets handshake errors): only
# records at ERROR level or above from these loggers are emitted.
for _quiet_logger in ("websockets.server", "asyncio"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)
del _quiet_logger
# 引入 Hugging Face 下载工具
from huggingface_hub import snapshot_download
# ==============================================================================
# 1. Configuration
# ==============================================================================
SERVER_PORT: int = 7860  # port the viser server listens on
DATA_ROOT: Path = Path("./assets")  # download target; each sub-directory is a scene
MAX_FRAMES: int = 400  # upper bound on the number of frames loaded per scene
DATASET_REPO_ID: str = "cyun9286/Holi4d_demo"  # Hugging Face dataset repo with the scenes
# ==============================================================================
# 2. 辅助计算函数
# ==============================================================================
def remove_radius_outlier_open3d(points: np.ndarray, nb_neighbors: int = 30, std_ratio: float = 2.0) -> Tuple[np.ndarray, np.ndarray]:
    """Remove outlier points with Open3D (CPU).

    Despite the function name, this uses Open3D's *statistical* outlier
    removal (``PointCloud.remove_statistical_outlier``), not radius-based
    removal.

    Args:
        points: (N, 3) array of XYZ coordinates.
        nb_neighbors: number of neighbours used to compute each point's
            average neighbour distance.
        std_ratio: threshold in standard deviations of those average
            distances; lower values filter more aggressively.

    Returns:
        ``(points_filtered, kept_indices)``: the surviving points and their
        integer indices into ``points``. For empty input, ``points`` is
        returned unchanged together with an empty integer index array, so the
        indices can always be used for indexing.
    """
    if points.shape[0] == 0:
        # An integer dtype keeps the result usable as an index array
        # (indexing with an empty float array raises IndexError).
        return points, np.array([], dtype=np.int64)
    pcd = o3d.geometry.PointCloud()
    # Vector3dVector converts a float64 (n, 3) array; coerce in case the
    # caller passes float32 or integer data.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))
    # Open3D returns (filtered_pcd, list_of_kept_indices).
    pcd_filtered, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
    return np.asarray(pcd_filtered.points), np.asarray(ind, dtype=np.int64)
def compute_trajectory_colors(trajectories: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Give every track a rainbow color derived from where it is first seen.

    Each track's first visible position is normalised per axis to [0, 1]
    (over all tracks), and the three coordinates are summed into a score.
    Tracks are ranked by that score (ties keep their original order), and
    the ranks are spread evenly over matplotlib's ``hsv`` colormap. Only the
    ordering of the scores matters. Tracks that are never visible get
    score 0, the lowest possible value.

    Args:
        trajectories: (T, N, 3) track positions per frame.
        mask: (T, N) boolean visibility of each track per frame.

    Returns:
        (N, 3) uint8 RGB colors, one per track. N == 0 yields an empty
        (0, 3) array. If no track is ever visible (including T == 0), all
        tracks tie at score 0.
    """
    num_tracks = trajectories.shape[1]
    if num_tracks == 0:
        return np.zeros((0, 3), dtype=np.uint8)

    never_visible = ~np.any(mask, axis=0)
    if never_visible.all():
        # Nothing to normalise (this also covers T == 0, where argmax would
        # raise); every track gets the same score.
        scalar = np.zeros(num_tracks)
    else:
        # Index of the first frame in which each track is visible
        # (0 for never-visible tracks, which are masked out below).
        first_visible_idx = np.argmax(mask, axis=0)
        # Fancy indexing returns a copy, so the NaN assignment below does not
        # modify the caller's array.
        first_visible_xyz = trajectories[first_visible_idx, np.arange(num_tracks)]
        first_visible_xyz[never_visible] = np.nan

        xyz_min = np.nanmin(first_visible_xyz, axis=0)
        denom = np.nanmax(first_visible_xyz, axis=0) - xyz_min
        denom[denom == 0] = 1.0  # avoid division by zero on degenerate axes
        xyz_norm = (first_visible_xyz - xyz_min) / denom
        # All-NaN rows (never-visible tracks) sum to 0.
        scalar = np.nansum(xyz_norm, axis=1)

    # Rank-based coloring; a stable sort makes tie-breaking deterministic.
    order = np.argsort(scalar, kind="stable")
    colors = np.zeros((num_tracks, 3))
    colors[order] = plt.cm.hsv(np.linspace(0, 1, num_tracks))[:, :3]
    return (colors * 255).astype(np.uint8)
# ==============================================================================
# 3. 全局状态管理
# ==============================================================================
class GlobalState:
    """Mutable state shared between the GUI callbacks and the render loop.

    ``lock`` is acquired around scene clearing, the final state update of a
    scene load, and the render loop's downsample-cache refresh.
    """

    def __init__(self):
        # Scene bookkeeping.
        self.is_loaded = False
        self.num_frames = 0
        self.lock = threading.Lock()

        # Scene handles: one point-cloud node per frame, one line-segment node.
        self.line_node: Optional[viser.SceneNodeHandle] = None
        self.point_nodes: List[viser.SceneNodeHandle] = []

        # Full-resolution track data as loaded from disk.
        self.visibility_mask_raw_all = None  # [T, N] boolean visibility
        self.trajectory_raw_all = None  # [T, N, 3] positions

        # Downsampled track data and its colors; -1 marks "not computed yet".
        self.current_downsample = -1
        self.initial_colors = None
        self.trajectories_3d = None
        self.visibility_mask = None

        # Trail history (trailing effect): one entry per drawn frame transition.
        self.history_lines_col = []
        self.history_lines_pos = []


# Module-wide singleton shared by the loading code and the render loop.
state = GlobalState()
# ==============================================================================
# 4. 数据加载逻辑
# ==============================================================================
def download_data() -> bool:
    """Download the ``DATASET_REPO_ID`` dataset repo into ``DATA_ROOT``.

    This is best effort: any failure is printed and swallowed so the viewer
    can still start with whatever data is already on disk.

    Returns:
        True if the download completed, False if it raised.
    """
    print(f"--- 开始下载数据: {DATASET_REPO_ID} ---")
    try:
        # The former local_dir_use_symlinks / resume_download flags are
        # deprecated in huggingface_hub and ignored. Files in local_dir are
        # real copies, and interrupted downloads resume automatically.
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_ROOT,
        )
    except Exception as e:  # top-level boundary: report and keep going
        print(f"!!! 数据下载失败: {e}")
        return False
    print("--- 数据下载完成 ---")
    return True
def clear_scene(server: viser.ViserServer):
    """Remove the current scene's nodes and drop all per-scene data.

    Marks the state as not loaded, removes every point-cloud node and the
    trajectory line node, and resets all cached trajectory arrays so that
    nothing from the old scene survives. ``server`` is unused and kept only
    for interface compatibility.
    """
    with state.lock:
        state.is_loaded = False
        for node in state.point_nodes:
            node.remove()
        state.point_nodes.clear()
        if state.line_node is not None:
            state.line_node.remove()
            state.line_node = None
        state.history_lines_pos = []
        state.history_lines_col = []
        state.num_frames = 0
        # Drop every array derived from the old scene. trajectories_3d and
        # visibility_mask are strided views of the raw arrays, so clearing only
        # trajectory_raw_all would keep the whole old buffer alive while the
        # next scene is being loaded.
        state.trajectory_raw_all = None
        state.visibility_mask_raw_all = None
        state.trajectories_3d = None
        state.visibility_mask = None
        state.initial_colors = None
        state.current_downsample = -1
def _frame_index(ply_file: Path) -> Optional[int]:
    """Return N for a file named ``frame_<N>.ply``, or None if N is not an integer."""
    try:
        return int(ply_file.stem.rsplit("_", 1)[-1])
    except ValueError:
        return None


def _list_ply_frames(scene_path: Path) -> List[Path]:
    """Return up to MAX_FRAMES ``frame_<N>.ply`` files of a scene, ordered by N.

    Files whose suffix after the last underscore is not an integer are skipped
    with a message instead of aborting the whole load.
    """
    indexed = []
    for ply_file in scene_path.glob("frame_*.ply"):
        index = _frame_index(ply_file)
        if index is None:
            print(f"Skipping {ply_file.name}: cannot parse frame number")
            continue
        indexed.append((index, ply_file))
    indexed.sort()
    return [ply_file for _, ply_file in indexed[:MAX_FRAMES]]


def _load_trajectories(traj_path: Path, num_frames: int) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Load the [T, N, 3] trajectory array and derive its [T, N] visibility mask.

    The array is truncated to ``num_frames`` frames. A point counts as visible
    in a frame when none of its coordinates is NaN. Returns ``(None, None)``
    when the file is missing or unusable; the scene is then shown without tracks.
    """
    if not traj_path.exists():
        return None, None
    print("加载并清洗轨迹数据 (这可能需要一点时间)...")
    try:
        traj = np.load(str(traj_path))
        if traj.ndim != 3 or traj.shape[-1] != 3:
            raise ValueError(f"expected a [T, N, 3] array, got shape {traj.shape}")
        traj = traj[:num_frames]
        vis_mask = ~np.isnan(traj).any(axis=-1)
    except Exception as e:  # best effort: fall back to a scene without tracks
        print(f"轨迹加载失败: {e}")
        return None, None
    print("轨迹数据处理完成")
    return traj, vis_mask


def _read_ply_frame(ply_file: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Read one PLY frame and return (points, uint8 RGB colors) in viewer coordinates.

    Points are scaled by 100 and their y/z axes flipped ([1, -1, -1]).
    Clouds without per-point colors get a uniform grey, so they are still shown.
    """
    pcd = o3d.io.read_point_cloud(str(ply_file))
    points = np.asarray(pcd.points) * 100 * np.array([1, -1, -1])
    colors = np.asarray(pcd.colors)
    if colors.shape != points.shape:
        colors = np.full(points.shape, 0.7)
    return points, (np.clip(colors, 0.0, 1.0) * 255).astype(np.uint8)


def _add_point_cloud_frames(server: viser.ViserServer, ply_files: List[Path]) -> list:
    """Add one hidden point-cloud node per PLY file, named ``/scene/frame_<i>``.

    A frame that fails to load gets an empty placeholder node, so node index i
    always corresponds to frame i of the trajectory data.
    """
    nodes = []
    for i, ply_file in enumerate(ply_files):
        try:
            points, colors = _read_ply_frame(ply_file)
        except Exception as e:
            print(f"帧 {i} 加载失败: {e}")
            points, colors = np.zeros((0, 3)), np.zeros((0, 3), dtype=np.uint8)
        nodes.append(server.scene.add_point_cloud(
            name=f"/scene/frame_{i}",
            points=points,
            colors=colors,
            point_size=0.01,
            point_shape="rounded",
            visible=False,
        ))
    return nodes


def load_scene_data(server: viser.ViserServer, scene_name: str, gui_line_width_val: float):
    """Replace the displayed scene with the one in ``DATA_ROOT / scene_name``.

    Steps: clear the current scene, collect the ``frame_<N>.ply`` files
    (at most MAX_FRAMES), optionally load ``trajectory_all_pointmap.npy``, add
    one hidden point-cloud node per frame and, if tracks exist, an empty
    line-segment node. Finally, publish everything to ``state`` under its lock.
    A "Loading..." label is shown meanwhile and is always removed, even on error.

    Args:
        server: the viser server to add scene nodes to.
        scene_name: name of the scene sub-directory of DATA_ROOT.
        gui_line_width_val: initial width of the trajectory line node.
    """
    scene_path = DATA_ROOT / scene_name
    print(f"正在加载场景: {scene_name} ...")
    loading_node = server.scene.add_label("loading", f"Loading {scene_name}...\nPlease wait.", position=(0, 0, 0))
    try:
        clear_scene(server)

        ply_files = _list_ply_frames(scene_path)
        if not ply_files:
            print("未找到 ply 文件")
            return

        traj_raw, vis_mask_raw = _load_trajectories(scene_path / "trajectory_all_pointmap.npy", len(ply_files))
        # Per-frame outlier filtering of the tracks (remove_radius_outlier_open3d)
        # is currently disabled.

        print(f"加载 {len(ply_files)} 帧点云...")
        new_nodes = _add_point_cloud_frames(server, ply_files)

        # Container for the trajectory trails; the render loop fills it.
        line_node = None
        if traj_raw is not None:
            line_node = server.scene.add_line_segments(
                name="/scene/trajectories",
                points=np.zeros((0, 2, 3)),
                colors=np.zeros((0, 2, 3), dtype=np.uint8),
                line_width=gui_line_width_val,
                visible=True,
            )

        with state.lock:
            state.point_nodes = new_nodes
            state.num_frames = len(new_nodes)
            state.trajectory_raw_all = traj_raw
            state.visibility_mask_raw_all = vis_mask_raw
            state.line_node = line_node
            state.current_downsample = -1  # force the render loop to rebuild its cache
            state.is_loaded = True
    finally:
        loading_node.remove()
    print("场景加载完毕。")
# ==============================================================================
# 5. 主程序
# ==============================================================================
# Trajectories are scaled by this factor to line up with the point clouds,
# which are also scaled by 100 when loaded.
_TRAJ_SCALE = 100.0
# Per-frame track displacements at least this long (in scaled units) are
# treated as tracking jumps ("teleport" artefacts) and not drawn.
_MAX_SEGMENT_LENGTH = 1.0


def _idle_forever() -> None:
    """Block forever, keeping the process (and the viser server) alive."""
    while True:
        time.sleep(1)


def _set_if_changed(handle, attr: str, value) -> None:
    """Set ``handle.<attr> = value`` only when the current value differs.

    The render loop runs every tick. Skipping no-op assignments avoids
    re-sending unchanged properties. NOTE(review): this assumes each viser
    property assignment is pushed to clients; confirm for the viser version in use.
    """
    if getattr(handle, attr) != value:
        setattr(handle, attr, value)


def _refresh_trajectory_cache(step: int) -> bool:
    """Re-slice the tracks to every ``step``-th one and recompute their colors.

    Must be called with ``state.lock`` held. Does nothing (returns False) when
    there is no trajectory data or ``step`` is already cached. Otherwise it
    rebuilds the cache, clears the trail history and returns True.
    """
    if state.trajectory_raw_all is None or step == state.current_downsample:
        return False
    state.current_downsample = step
    state.trajectories_3d = state.trajectory_raw_all[:, ::step]
    state.visibility_mask = state.visibility_mask_raw_all[:, ::step]
    state.initial_colors = compute_trajectory_colors(state.trajectories_3d, state.visibility_mask)
    state.history_lines_pos = []
    state.history_lines_col = []
    return True


def _trail_segment(t: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return (segments, colors) for the track motion from frame t-1 to frame t.

    Both arrays have shape (K, 2, 3). K is 0 when nothing can be drawn: t is
    out of range, no track is visible in both frames, or every displacement is
    a jump of at least _MAX_SEGMENT_LENGTH.
    """
    empty = (np.zeros((0, 2, 3)), np.zeros((0, 2, 3), dtype=np.uint8))
    traj, vis, colors = state.trajectories_3d, state.visibility_mask, state.initial_colors
    if traj is None or vis is None or colors is None or not 1 <= t < traj.shape[0]:
        return empty
    valid = vis[t - 1] & vis[t]
    if not valid.any():
        return empty
    # NOTE(review): unlike the point clouds, no [1, -1, -1] axis flip is applied
    # here. The original code explicitly left it commented out, presumably
    # because the trajectory file is already in the flipped frame; verify.
    p1 = traj[t - 1][valid] * _TRAJ_SCALE
    p2 = traj[t][valid] * _TRAJ_SCALE
    keep = np.linalg.norm(p2 - p1, axis=1) < _MAX_SEGMENT_LENGTH
    if not keep.any():
        return empty
    cols = colors[valid][keep]
    return np.stack([p1[keep], p2[keep]], axis=1), np.stack([cols, cols], axis=1)


def _update_trail(line_node, t_curr: int, trail_frame: int, max_len: int) -> int:
    """Make the drawn trail cover the ``max_len`` frame transitions ending at ``t_curr``.

    If ``t_curr`` is exactly one frame after ``trail_frame``, one transition is
    appended. Any other change (wrap-around to 0, scrubbing, first draw after a
    reset) rebuilds the whole window, so stale segments are never shown. When
    ``t_curr == trail_frame`` (paused), only trimming can happen. The line node
    is updated only if the history changed. Returns the frame the trail now
    ends at, to be passed back as ``trail_frame`` next tick.
    """
    pos, col = state.history_lines_pos, state.history_lines_col
    changed = False
    if t_curr != trail_frame:
        if t_curr > 0 and t_curr == trail_frame + 1:
            transitions = [t_curr]
        else:
            pos.clear()
            col.clear()
            transitions = range(max(1, t_curr - max_len + 1), t_curr + 1)
        for t in transitions:
            seg_pos, seg_col = _trail_segment(t)
            pos.append(seg_pos)
            col.append(seg_col)
        changed = True
    while len(pos) > max_len:  # honour a reduced "Trail Length"
        pos.pop(0)
        col.pop(0)
        changed = True
    if changed:
        if pos:
            line_node.points = np.concatenate(pos, axis=0)
            line_node.colors = np.concatenate(col, axis=0)
        else:
            line_node.points = np.zeros((0, 2, 3))
            line_node.colors = np.zeros((0, 2, 3), dtype=np.uint8)
    return t_curr


def _update_point_cloud(nodes, shown_frame: int, t_curr: int, show_points: bool, point_size: float) -> None:
    """Hide the previously shown frame (if it changed) and apply visibility/size to ``t_curr``.

    Visibility and point size are re-applied every tick, so vis-mode and size
    changes also take effect while playback is paused.
    """
    if t_curr != shown_frame and 0 <= shown_frame < len(nodes):
        nodes[shown_frame].visible = False
    if 0 <= t_curr < len(nodes):
        node = nodes[t_curr]
        _set_if_changed(node, "visible", show_points)
        if show_points:
            _set_if_changed(node, "point_size", point_size)


def main() -> None:
    """Download the data, start the viser server, build the GUI and run the render loop.

    Never returns. When data or scenes are missing, an error label is shown
    and the process idles, keeping the server up.
    """
    download_data()
    print(f"启动 Viser 服务器,端口: {SERVER_PORT}...")
    server = viser.ViserServer(host="0.0.0.0", port=SERVER_PORT)

    if not DATA_ROOT.exists():
        server.scene.add_label("err", "Error: Data folder empty", position=(0, 0, 0))
        _idle_forever()

    # Scene folders are the non-hidden sub-directories of DATA_ROOT.
    scene_names = sorted(
        d.name for d in DATA_ROOT.iterdir() if d.is_dir() and not d.name.startswith(".")
    )
    if not scene_names:
        server.scene.add_label("err", "No scenes found", position=(0, 0, 0))
        _idle_forever()

    # --- GUI ---
    with server.gui.add_folder("Scene Control"):
        gui_scene_select = server.gui.add_dropdown("Select Scene", options=scene_names, initial_value=scene_names[0])
    with server.gui.add_folder("Playback Controls"):
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=24)
        gui_timestep = server.gui.add_slider("Timestep", min=0, max=100, step=1, initial_value=0)
        gui_point_size = server.gui.add_slider("Point size", min=0.001, max=0.05, step=0.001, initial_value=0.01)
        gui_line_width = server.gui.add_slider("Line width", min=0.1, max=5.0, step=0.1, initial_value=0.5)
        gui_max_traj_length = server.gui.add_slider("Trail Length", min=1, max=50, step=1, initial_value=5)
        gui_downsample = server.gui.add_slider("Downsample", min=1, max=100, step=1, initial_value=100)
        gui_vis_mode = server.gui.add_button_group("Vis Mode", ("PointCloud", "Tracking", "Both"))
        gui_vis_mode.value = "Both"

    def _reset_timeline() -> None:
        # Keep max >= 1 so the slider range is never empty; the render loop
        # clamps the value to the frames that actually exist.
        gui_timestep.max = max(1, state.num_frames - 1)
        gui_timestep.value = 0

    @gui_scene_select.on_update
    def _(_) -> None:
        load_scene_data(server, gui_scene_select.value, gui_line_width.value)
        _reset_timeline()

    # Initial load.
    load_scene_data(server, scene_names[0], gui_line_width.value)
    _reset_timeline()

    # --- Render loop ---
    shown_nodes = None  # point_nodes list of the scene currently on screen
    shown_frame = -1  # frame whose point cloud was last made current
    trail_frame = -1  # frame the cached trail currently ends at (-1: none)
    while True:
        delay = 0.1  # poll interval while no scene is loaded
        # Hold the lock for the whole tick so a scene switch running in the GUI
        # callback (clear_scene / load_scene_data) cannot remove nodes or reset
        # state halfway through an update.
        with state.lock:
            if state.is_loaded and state.num_frames > 0:
                delay = 1.0 / gui_framerate.value
                num_frames = state.num_frames

                # 1. Playback: clamp a possibly stale slider value, then advance.
                t_curr = min(max(int(gui_timestep.value), 0), num_frames - 1)
                if gui_playing.value:
                    t_curr = (t_curr + 1) % num_frames
                    gui_timestep.value = t_curr

                # 2. A newly loaded scene has nothing on screen yet.
                if state.point_nodes is not shown_nodes:
                    shown_nodes, shown_frame, trail_frame = state.point_nodes, -1, -1
                # Slider values may arrive as floats; slicing needs an int step.
                if _refresh_trajectory_cache(max(1, int(gui_downsample.value))):
                    trail_frame = -1

                # 3. Point cloud of the current frame.
                mode = gui_vis_mode.value
                _update_point_cloud(
                    shown_nodes, shown_frame, t_curr,
                    mode in ("PointCloud", "Both"), float(gui_point_size.value),
                )
                shown_frame = t_curr

                # 4. Trajectory trails.
                line_node = state.line_node
                if line_node is not None:
                    show_lines = mode in ("Tracking", "Both")
                    _set_if_changed(line_node, "visible", show_lines)
                    if show_lines:
                        _set_if_changed(line_node, "line_width", float(gui_line_width.value))
                        trail_frame = _update_trail(
                            line_node, t_curr, trail_frame, max(1, int(gui_max_traj_length.value)),
                        )
        time.sleep(delay)


if __name__ == "__main__":
    main()