|
|
|
|
|
""" |
|
|
WaveGen 训练结果可视化工具 (独立版本) |
|
|
自动检索 core_space 目录并可视化训练输出 |
|
|
|
|
|
Usage: |
|
|
cd code/WaveGen/nano_WaveGen |
|
|
python utils/visualize_training.py |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import viser |
|
|
import viser.transforms as viser_tf |
|
|
from typing import Optional, Dict, List, Tuple, Any |
|
|
import os |
|
|
from pathlib import Path |
|
|
import json |
|
|
import cv2 |
|
|
import time |
|
|
import webbrowser |
|
|
from scipy.spatial.transform import Rotation |
|
|
import threading |
|
|
|
|
|
|
|
|
try: |
|
|
from depth_to_pointcloud import DepthToPointCloud |
|
|
except ImportError: |
|
|
|
|
|
import sys |
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
from depth_to_pointcloud import DepthToPointCloud |
|
|
|
|
|
|
|
|
class TrainingVisualizer: |
|
|
"""WaveGen训练结果可视化器""" |
|
|
|
|
|
    def __init__(self, core_space_dir: str = "core_space", port: int = 8080):
        """Initialize the visualizer: start a viser server, scan for training
        outputs and build the GUI.

        Args:
            core_space_dir: core_space directory path (relative to the current
                working directory unless absolute).
            port: starting port number; if occupied the next ports are tried.
        """
        self.core_space_dir = Path(core_space_dir)
        if not self.core_space_dir.is_absolute():
            self.core_space_dir = Path.cwd() / self.core_space_dir

        # Start the viser server, retrying successive ports when the
        # requested one is already bound by another instance.
        self.server = None
        self.port = port
        max_attempts = 10

        for attempt in range(max_attempts):
            try_port = port + attempt
            try:
                self.server = viser.ViserServer(port=try_port, show_config=False)
                self.port = try_port
                print(f"🌐 Viser服务器已启动: http://localhost:{try_port}")
                if attempt > 0:
                    print(f" (端口 {port} 被占用,自动使用端口 {try_port})")
                break
            except OSError as e:
                # Only "address in use" triggers a retry; anything else is fatal.
                if "Address already in use" in str(e):
                    if attempt == max_attempts - 1:
                        print(f"❌ 无法找到可用端口 (尝试了 {port}-{try_port})")
                        print(f" 请手动关闭其他实例: pkill -f visualize_training.py")
                        raise
                    continue
                else:
                    raise

        # Scene-object handles; populated per frame and cleared on redraw.
        self.superquadric_handles = []
        self.gt_superquadric_handles = []
        self.camera_handles = []
        self.camera_frustum_handles = []
        self.point_cloud_handle = None
        self.camera_rgb_handle = None
        self.coordinate_frame_handle = None
        self.mesh_handles_pool = {}
        self.object_label_handles = []

        # Currently-loaded sample state.
        self.predictions_npz = None
        self.targets_npz = None
        self.current_sample_path = None
        self.current_frame = 0
        self.original_frame_count = 0
        self.scene_center = np.array([0, 0, 0])
        self.scene_scale = 1.0

        # GUI widget handles, keyed by a stable control name.
        self.gui_controls = {}

        # Playback flag read by the background playback thread.
        self.is_playing = False

        # Video-export state (used by the export callbacks).
        self.is_exporting = False
        self.export_progress = 0
        self.export_camera_pos = None
        self.export_camera_wxyz = None

        self.setup_scene()

        # Populates self.training_outputs before the GUI reads it.
        self.scan_training_outputs()

        self.setup_gui()

        print("✅ 训练可视化器已初始化")
        print(f"📁 监控目录: {self.core_space_dir}")

        if len(self.training_outputs) == 0:
            print("⚠️ 未找到训练输出,请检查 core_space 目录")
|
|
|
|
def setup_scene(self): |
|
|
"""设置场景背景和坐标系""" |
|
|
|
|
|
self.update_background(wireframe_mode=False) |
|
|
|
|
|
|
|
|
self.server.scene.set_up_direction("+y") |
|
|
|
|
|
def update_background(self, wireframe_mode: bool): |
|
|
"""更新场景背景颜色""" |
|
|
if wireframe_mode: |
|
|
|
|
|
bg_color = [0, 0, 0] |
|
|
else: |
|
|
|
|
|
bg_color = [13, 13, 38] |
|
|
|
|
|
width, height = 1920, 1080 |
|
|
solid_color_image = np.full((height, width, 3), bg_color, dtype=np.uint8) |
|
|
self.server.scene.set_background_image(solid_color_image, format="png") |
|
|
|
|
|
def scan_training_outputs(self): |
|
|
"""扫描core_space目录下的训练输出""" |
|
|
self.training_outputs = [] |
|
|
|
|
|
if not self.core_space_dir.exists(): |
|
|
print(f"⚠️ core_space目录不存在: {self.core_space_dir}") |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
for output_dir in sorted(self.core_space_dir.glob("*_text2wave"), reverse=True): |
|
|
if output_dir.is_dir(): |
|
|
|
|
|
sample_dirs = sorted(output_dir.glob("sample_*")) |
|
|
if sample_dirs: |
|
|
self.training_outputs.append({ |
|
|
'path': output_dir, |
|
|
'name': output_dir.name, |
|
|
'samples': len(sample_dirs) |
|
|
}) |
|
|
|
|
|
print(f"📦 找到 {len(self.training_outputs)} 个训练输出") |
|
|
for output in self.training_outputs: |
|
|
print(f" - {output['name']} ({output['samples']} 样本)") |
|
|
|
|
|
    def setup_gui(self):
        """Build all GUI folders/widgets and wire their callbacks.

        Handles are stored in `self.gui_controls` keyed by a stable name so
        other methods can read widget values without keeping references.
        """
        # --- Training-output selection -------------------------------------
        with self.server.gui.add_folder("训练输出"):
            if self.training_outputs:
                output_names = [out['name'] for out in self.training_outputs]
                self.gui_controls['output_selector'] = self.server.gui.add_dropdown(
                    "选择训练输出",
                    options=output_names,
                    initial_value=output_names[0]
                )
                self.gui_controls['output_selector'].on_update(self._on_output_change)

                # Range follows the first output; reset on output change.
                self.gui_controls['sample_slider'] = self.server.gui.add_slider(
                    "样本索引",
                    min=0,
                    max=max(0, self.training_outputs[0]['samples'] - 1),
                    step=1,
                    initial_value=0
                )
                self.gui_controls['sample_slider'].on_update(self._on_sample_change)

                self.gui_controls['load_button'] = self.server.gui.add_button("加载样本")
                self.gui_controls['load_button'].on_click(self._on_load_sample)
            else:
                self.server.gui.add_text("状态", initial_value="未找到训练输出")

        # --- Frame playback -------------------------------------------------
        with self.server.gui.add_folder("帧控制"):
            # Max is provisional (23); load_sample resets it to the real count.
            self.gui_controls['frame_slider'] = self.server.gui.add_slider(
                "当前帧",
                min=0,
                max=23,
                step=1,
                initial_value=0
            )
            self.gui_controls['frame_slider'].on_update(self._on_frame_change)

            self.gui_controls['play_button'] = self.server.gui.add_button("▶ 播放")
            self.gui_controls['play_button'].on_click(self._on_play)

            self.gui_controls['pause_button'] = self.server.gui.add_button("⏸ 暂停")
            self.gui_controls['pause_button'].on_click(self._on_pause)

            self.gui_controls['fps_slider'] = self.server.gui.add_slider(
                "播放FPS",
                min=1,
                max=30,
                step=1,
                initial_value=8
            )

        # --- Generated (predicted) superquadrics ---------------------------
        with self.server.gui.add_folder("生成结果"):
            self.gui_controls['show_generated'] = self.server.gui.add_checkbox(
                "显示生成的超二次曲面", initial_value=True
            )
            self.gui_controls['show_generated'].on_update(self._on_visibility_change)

            self.gui_controls['generated_opacity'] = self.server.gui.add_slider(
                "生成结果透明度", min=0.1, max=1.0, step=0.05, initial_value=0.7
            )
            self.gui_controls['generated_opacity'].on_update(self._on_opacity_change)

            self.gui_controls['generated_color'] = self.server.gui.add_rgb(
                "生成结果颜色", initial_value=(100, 149, 237)
            )
            self.gui_controls['generated_color'].on_update(self._on_color_change)

        # --- Ground-truth superquadrics ------------------------------------
        with self.server.gui.add_folder("Ground Truth"):
            self.gui_controls['show_gt'] = self.server.gui.add_checkbox(
                "显示GT超二次曲面", initial_value=True
            )
            self.gui_controls['show_gt'].on_update(self._on_visibility_change)

            self.gui_controls['gt_opacity'] = self.server.gui.add_slider(
                "GT透明度", min=0.1, max=1.0, step=0.05, initial_value=0.5
            )
            self.gui_controls['gt_opacity'].on_update(self._on_opacity_change)

            self.gui_controls['gt_color'] = self.server.gui.add_rgb(
                "GT颜色", initial_value=(255, 99, 71)
            )
            self.gui_controls['gt_color'].on_update(self._on_color_change)

            self.gui_controls['show_object_info'] = self.server.gui.add_checkbox(
                "显示物体信息", initial_value=False
            )
            self.gui_controls['show_object_info'].on_update(self._on_visibility_change)

        # --- Point-cloud display --------------------------------------------
        with self.server.gui.add_folder("点云显示"):
            self.gui_controls['show_pointcloud'] = self.server.gui.add_checkbox(
                "显示点云", initial_value=True
            )
            self.gui_controls['show_pointcloud'].on_update(self._on_visibility_change)

            self.gui_controls['pointcloud_size'] = self.server.gui.add_slider(
                "点大小", min=0.001, max=0.02, step=0.001, initial_value=0.008
            )
            self.gui_controls['pointcloud_size'].on_update(self._on_visibility_change)

        # --- Rendering settings ----------------------------------------------
        with self.server.gui.add_folder("渲染设置"):
            self.gui_controls['mesh_resolution'] = self.server.gui.add_slider(
                "网格分辨率", min=10, max=50, step=5, initial_value=25
            )
            self.gui_controls['mesh_resolution'].on_update(self._on_mesh_resolution_change)

            self.gui_controls['show_coordinate'] = self.server.gui.add_checkbox(
                "显示坐标系", initial_value=False
            )
            self.gui_controls['show_coordinate'].on_update(self._on_visibility_change)

            self.gui_controls['wireframe_mode'] = self.server.gui.add_checkbox(
                "线框模式 (黑白边缘)", initial_value=False
            )
            self.gui_controls['wireframe_mode'].on_update(self._on_wireframe_mode_change)

        # --- Camera controls --------------------------------------------------
        with self.server.gui.add_folder("相机控制"):
            self.gui_controls['reset_view'] = self.server.gui.add_button("重置视角")
            self.gui_controls['reset_view'].on_click(self._on_reset_view)

            self.gui_controls['match_camera'] = self.server.gui.add_button("匹配GT相机")
            self.gui_controls['match_camera'].on_click(self._on_match_camera)

            self.gui_controls['show_target_frustum'] = self.server.gui.add_checkbox(
                "显示GT相机椎体", initial_value=True
            )
            self.gui_controls['show_pred_frustum'] = self.server.gui.add_checkbox(
                "显示预测相机椎体", initial_value=True
            )
            self.gui_controls['show_camera_rgb'] = self.server.gui.add_checkbox(
                "相机视锥显示RGB", initial_value=True
            )
            self.gui_controls['show_target_frustum'].on_update(self._on_visibility_change)
            self.gui_controls['show_pred_frustum'].on_update(self._on_visibility_change)
            self.gui_controls['show_camera_rgb'].on_update(self._on_visibility_change)

        # --- Video export -----------------------------------------------------
        # NOTE(review): the export callbacks (_on_capture_camera,
        # _on_export_viser, _on_export_video) are defined elsewhere in the file.
        with self.server.gui.add_folder("视频导出"):
            self.gui_controls['export_status'] = self.server.gui.add_text(
                "状态", initial_value="就绪"
            )

            self.gui_controls['export_resolution'] = self.server.gui.add_slider(
                "导出分辨率", min=480, max=1080, step=120, initial_value=720
            )

            self.gui_controls['capture_camera_button'] = self.server.gui.add_button(
                "📸 捕获当前视角"
            )
            self.gui_controls['capture_camera_button'].on_click(self._on_capture_camera)

            self.gui_controls['export_viser_button'] = self.server.gui.add_button(
                "💾 导出场景(.viser)"
            )
            self.gui_controls['export_viser_button'].on_click(self._on_export_viser)

            self.gui_controls['export_button'] = self.server.gui.add_button("🎬 导出视频(MP4)")
            self.gui_controls['export_button'].on_click(self._on_export_video)

        print(f"✅ GUI 已设置 - 创建了 {len(self.gui_controls)} 个控件")
|
|
|
|
def _on_output_change(self, event): |
|
|
"""训练输出选择改变""" |
|
|
selected_name = event.target.value |
|
|
for i, output in enumerate(self.training_outputs): |
|
|
if output['name'] == selected_name: |
|
|
|
|
|
max_sample = max(0, output['samples'] - 1) |
|
|
self.gui_controls['sample_slider'].max = max_sample |
|
|
self.gui_controls['sample_slider'].value = 0 |
|
|
break |
|
|
|
|
|
    def _on_sample_change(self, event):
        """Sample-slider callback; intentionally a no-op.

        Loading only happens when the "加载样本" button is clicked
        (see `_on_load_sample`), so dragging the slider is cheap.
        """
        pass
|
|
|
|
def _on_load_sample(self, event): |
|
|
"""加载选中的样本""" |
|
|
selected_name = self.gui_controls['output_selector'].value |
|
|
sample_idx = int(self.gui_controls['sample_slider'].value) |
|
|
|
|
|
|
|
|
output_path = None |
|
|
for output in self.training_outputs: |
|
|
if output['name'] == selected_name: |
|
|
output_path = output['path'] |
|
|
break |
|
|
|
|
|
if output_path is None: |
|
|
print(f"❌ 未找到训练输出: {selected_name}") |
|
|
return |
|
|
|
|
|
self.load_sample(output_path, sample_idx) |
|
|
|
|
|
    def load_sample(self, output_path: Path, sample_idx: int):
        """Load predictions/targets npz for one sample and render frame 0.

        Args:
            output_path: training-output directory containing sample_* folders.
            sample_idx: index of the sample_<idx> folder to load.
        """
        sample_path = output_path / f"sample_{sample_idx}"

        if not sample_path.exists():
            print(f"❌ 样本目录不存在: {sample_path}")
            return

        print(f"\n{'='*60}")
        print(f"📂 加载样本: {output_path.name}/sample_{sample_idx}")
        print(f"{'='*60}")

        self.current_sample_path = sample_path

        # Predictions: copy arrays out of the archive so it can be closed.
        pred_file = sample_path / "predictions.npz"
        if pred_file.exists():
            npz_data = np.load(pred_file, allow_pickle=True)
            self.predictions_npz = {key: npz_data[key] for key in npz_data.files}
            npz_data.close()
            print(f"✅ 加载predictions.npz: {pred_file}")
            if 'frames' in self.predictions_npz:
                print(f" 帧数: {len(self.predictions_npz['frames'])}")
            if 'text' in self.predictions_npz:
                print(f" 文本: {self.predictions_npz['text']}")
        else:
            self.predictions_npz = None
            print(f"⚠️ 未找到predictions.npz")

        # Ground-truth targets, same layout as predictions.
        target_file = sample_path / "targets.npz"
        if target_file.exists():
            npz_data = np.load(target_file, allow_pickle=True)
            self.targets_npz = {key: npz_data[key] for key in npz_data.files}
            npz_data.close()
            print(f"✅ 加载targets.npz: {target_file}")
            if 'frames' in self.targets_npz:
                print(f" 帧数: {len(self.targets_npz['frames'])}")
            if 'text' in self.targets_npz:
                print(f" 文本: {self.targets_npz['text']}")
        else:
            self.targets_npz = None
            print(f"⚠️ 未找到targets.npz")

        # Frame count: prefer the predictions 'frames' list, falling back to
        # the legacy 'objects' array shape in targets.
        self.original_frame_count = 0
        if self.predictions_npz and 'frames' in self.predictions_npz:
            self.original_frame_count = len(self.predictions_npz['frames'])
        elif self.targets_npz and 'objects' in self.targets_npz:
            objects = self.targets_npz['objects']
            if hasattr(objects, 'shape') and len(objects.shape) >= 1:
                self.original_frame_count = objects.shape[0]

        if self.original_frame_count > 0:
            # Resize the frame slider to the loaded sequence and rewind.
            self.gui_controls['frame_slider'].max = self.original_frame_count - 1
            self.gui_controls['frame_slider'].value = 0
            self.current_frame = 0
            print(f"📊 总帧数: {self.original_frame_count}")

        # Show the first frame immediately.
        self.visualize_frame(0)
|
|
|
|
def _on_frame_change(self, event): |
|
|
"""帧滑块改变""" |
|
|
frame_idx = int(event.target.value) |
|
|
self.visualize_frame(frame_idx) |
|
|
|
|
|
def _on_play(self, event): |
|
|
"""开始播放""" |
|
|
self.is_playing = True |
|
|
print("▶ 开始播放") |
|
|
|
|
|
|
|
|
import threading |
|
|
threading.Thread(target=self._playback_loop, daemon=True).start() |
|
|
|
|
|
def _on_pause(self, event): |
|
|
"""暂停播放""" |
|
|
self.is_playing = False |
|
|
print("⏸ 暂停播放") |
|
|
|
|
|
def _playback_loop(self): |
|
|
"""播放循环""" |
|
|
while self.is_playing: |
|
|
current_frame = int(self.gui_controls['frame_slider'].value) |
|
|
next_frame = (current_frame + 1) % self.original_frame_count |
|
|
|
|
|
self.gui_controls['frame_slider'].value = next_frame |
|
|
self.visualize_frame(next_frame) |
|
|
|
|
|
fps = int(self.gui_controls['fps_slider'].value) |
|
|
time.sleep(1.0 / fps) |
|
|
|
|
|
def _on_visibility_change(self, event): |
|
|
"""可见性改变""" |
|
|
self.visualize_frame(self.current_frame) |
|
|
|
|
|
def _on_opacity_change(self, event): |
|
|
"""透明度改变""" |
|
|
self.visualize_frame(self.current_frame) |
|
|
|
|
|
def _on_color_change(self, event): |
|
|
"""颜色改变""" |
|
|
self.visualize_frame(self.current_frame) |
|
|
|
|
|
def _on_mesh_resolution_change(self, event): |
|
|
"""网格分辨率改变""" |
|
|
|
|
|
for mesh in self.mesh_handles_pool.values(): |
|
|
mesh.remove() |
|
|
self.mesh_handles_pool.clear() |
|
|
self.visualize_frame(self.current_frame) |
|
|
|
|
|
def _on_wireframe_mode_change(self, event): |
|
|
"""线框模式改变""" |
|
|
wireframe_mode = event.target.value |
|
|
|
|
|
|
|
|
self.update_background(wireframe_mode) |
|
|
|
|
|
|
|
|
for mesh in self.mesh_handles_pool.values(): |
|
|
mesh.remove() |
|
|
self.mesh_handles_pool.clear() |
|
|
|
|
|
|
|
|
self.visualize_frame(self.current_frame) |
|
|
|
|
|
def _on_reset_view(self, event): |
|
|
"""重置视角""" |
|
|
|
|
|
for client in self.server.get_clients().values(): |
|
|
client.camera.position = (3.0, 2.0, 3.0) |
|
|
client.camera.look_at = (0.0, 0.0, 0.0) |
|
|
|
|
|
    def _on_match_camera(self, event):
        """Move client cameras to the GT camera pose of the current frame.

        Uses the new per-frame dict format: each frame carries a
        'world_info' dict with 'camera_position' and an xyzw
        'camera_quaternion'.
        """
        if self.targets_npz is None or 'frames' not in self.targets_npz:
            print("⚠️ 没有GT相机数据")
            return

        frame_idx = self.current_frame
        frames = self.targets_npz['frames']

        if frame_idx >= len(frames):
            print("⚠️ 帧索引超出范围")
            return

        frame_data = frames[frame_idx]

        # npz stores each frame as a 0-d object array wrapping a dict.
        if isinstance(frame_data, np.ndarray):
            frame_data = frame_data.item()

        if 'world_info' not in frame_data:
            print("⚠️ 未找到world_info数据")
            return

        world_info = frame_data['world_info']
        camera_position = world_info['camera_position']

        # Stored quaternion is xyzw; viser's camera expects wxyz order.
        q_xyzw = np.array(world_info['camera_quaternion'], dtype=np.float32)
        wxyz = (float(q_xyzw[3]), float(q_xyzw[0]), float(q_xyzw[1]), float(q_xyzw[2]))

        # Map the world-space position into the normalized viewer space so it
        # lines up with the point cloud / meshes.
        cam_pos_vis = (np.array(camera_position) - self.scene_center) * self.scene_scale

        print(f"📷 匹配相机: pos={camera_position}, quat={wxyz}")

        for client in self.server.get_clients().values():
            client.camera.position = tuple(cam_pos_vis)
            client.camera.wxyz = wxyz
|
|
|
|
    def visualize_frame(self, frame_idx: int):
        """Render one frame: point cloud, predicted/GT superquadrics, labels,
        coordinate frame and camera frusta.

        Args:
            frame_idx: frame to show; clamped into the loaded frame range.
        """
        if self.original_frame_count <= 0:
            # Nothing loaded yet.
            return

        frame_idx = int(np.clip(frame_idx, 0, self.original_frame_count - 1))
        self.current_frame = frame_idx

        print(f"\n🎨 可视化帧 {frame_idx}/{self.original_frame_count-1}")

        # Remove last frame's handles before redrawing.
        self.clear_visualization()

        # Snapshot all GUI state up front.
        show_generated = self.gui_controls['show_generated'].value
        show_gt = self.gui_controls['show_gt'].value
        show_pointcloud = self.gui_controls['show_pointcloud'].value
        show_coordinate = self.gui_controls['show_coordinate'].value

        generated_opacity = self.gui_controls['generated_opacity'].value
        gt_opacity = self.gui_controls['gt_opacity'].value
        generated_color = tuple(self.gui_controls['generated_color'].value)
        gt_color = tuple(self.gui_controls['gt_color'].value)
        mesh_resolution = int(self.gui_controls['mesh_resolution'].value)

        predictions = self._extract_predictions(frame_idx)
        targets = self._extract_targets(frame_idx)

        # Default normalization: no shift, unit scale.
        self.scene_center = np.zeros(3, dtype=np.float32)
        self.scene_scale = 1.0

        # Preferred source: scene_normalization.json stored with the sample.
        norm_path = None
        if self.current_sample_path is not None:
            norm_path = self.current_sample_path / "original_data" / "scene_normalization.json"
        loaded_norm = False
        if norm_path is not None and norm_path.exists():
            try:
                with open(norm_path) as f:
                    norm = json.load(f)
                if 'scene_center' in norm:
                    self.scene_center = np.array(norm['scene_center'], dtype=np.float32)
                if 'scene_scale' in norm:
                    self.scene_scale = float(norm['scene_scale'])
                elif 'scene_extent' in norm and norm['scene_extent']:
                    # Scale so the scene extent maps to 20 viewer units.
                    self.scene_scale = 20.0 / float(norm['scene_extent'])
                loaded_norm = True
            except Exception:
                loaded_norm = False

        # Fallback: normalization embedded in the GT world_info.
        if not loaded_norm:
            wi = self._get_world_info(frame_idx, source="targets")
            if wi is not None:
                if 'scene_center' in wi:
                    self.scene_center = np.array(wi['scene_center'], dtype=np.float32)
                if 'scene_scale' in wi:
                    try:
                        self.scene_scale = float(wi['scene_scale'])
                    except Exception:
                        pass

        # Wireframe mode suppresses point cloud, labels and camera frusta.
        wireframe_mode = self.gui_controls.get('wireframe_mode', None)
        is_wireframe = wireframe_mode.value if wireframe_mode else False

        if show_pointcloud and not is_wireframe:
            self._visualize_pointcloud(frame_idx, scene_center=self.scene_center, scene_scale=self.scene_scale)

        if show_generated and predictions is not None:
            self._visualize_superquadrics(
                predictions,
                color=generated_color,
                opacity=generated_opacity,
                mesh_resolution=mesh_resolution,
                is_gt=False
            )

        if show_gt and targets is not None:
            self._visualize_superquadrics(
                targets,
                color=gt_color,
                opacity=gt_opacity,
                mesh_resolution=mesh_resolution,
                is_gt=True
            )

        # Floating per-object info labels (GT only).
        show_info = self.gui_controls['show_object_info'].value
        if show_info and not is_wireframe:
            self._visualize_object_labels(frame_idx, targets, is_gt=True)

        if show_coordinate:
            self.coordinate_frame_handle = self.server.scene.add_frame(
                "/coordinate",
                wxyz=(1, 0, 0, 0),
                position=(0, 0, 0),
                axes_length=1.0,
                axes_radius=0.01
            )

        if not is_wireframe:
            self._visualize_cameras(frame_idx)
|
|
|
|
def _extract_predictions(self, frame_idx: int) -> Optional[np.ndarray]: |
|
|
"""提取预测数据 (新格式)""" |
|
|
if self.predictions_npz is None or 'frames' not in self.predictions_npz: |
|
|
return None |
|
|
|
|
|
frames = self.predictions_npz['frames'] |
|
|
if frame_idx >= len(frames): |
|
|
return None |
|
|
|
|
|
frame_data = frames[frame_idx] |
|
|
|
|
|
|
|
|
if isinstance(frame_data, np.ndarray): |
|
|
frame_data = frame_data.item() |
|
|
|
|
|
if 'superquadrics' not in frame_data: |
|
|
return None |
|
|
|
|
|
superquadrics = frame_data['superquadrics'] |
|
|
objects_array = [] |
|
|
|
|
|
for sq in superquadrics: |
|
|
|
|
|
obj_params = np.zeros(15, dtype=np.float32) |
|
|
obj_params[0] = 1.0 if sq['exists'] else 0.0 |
|
|
obj_params[1:3] = sq['shape'] |
|
|
obj_params[3:6] = sq['scale'] |
|
|
obj_params[6:9] = sq['translation'] |
|
|
obj_params[9:12] = sq['rotation'] |
|
|
obj_params[12:15] = sq['velocity'] |
|
|
objects_array.append(obj_params) |
|
|
|
|
|
return np.array(objects_array, dtype=np.float32) |
|
|
|
|
|
def _extract_targets(self, frame_idx: int) -> Optional[np.ndarray]: |
|
|
"""提取GT数据 (新格式)""" |
|
|
if self.targets_npz is None or 'frames' not in self.targets_npz: |
|
|
return None |
|
|
|
|
|
frames = self.targets_npz['frames'] |
|
|
if frame_idx >= len(frames): |
|
|
return None |
|
|
|
|
|
frame_data = frames[frame_idx] |
|
|
|
|
|
|
|
|
if isinstance(frame_data, np.ndarray): |
|
|
frame_data = frame_data.item() |
|
|
|
|
|
if 'superquadrics' not in frame_data: |
|
|
return None |
|
|
|
|
|
superquadrics = frame_data['superquadrics'] |
|
|
objects_array = [] |
|
|
|
|
|
for sq in superquadrics: |
|
|
|
|
|
obj_params = np.zeros(16, dtype=np.float32) |
|
|
obj_params[0] = 1.0 if sq['exists'] else 0.0 |
|
|
obj_params[1:3] = sq['shape'] |
|
|
obj_params[3:6] = sq['scale'] |
|
|
obj_params[6:9] = sq['translation'] |
|
|
obj_params[9:12] = sq['rotation'] |
|
|
obj_params[12] = sq['inlier_ratio'] |
|
|
obj_params[13:16] = sq['velocity'] |
|
|
objects_array.append(obj_params) |
|
|
|
|
|
return np.array(objects_array, dtype=np.float32) |
|
|
|
|
|
    def _visualize_superquadrics(self, objects: np.ndarray, color: Tuple,
                                 opacity: float, mesh_resolution: int, is_gt: bool):
        """Add one mesh per active superquadric in `objects` to the scene.

        Args:
            objects: (N, 15|16) parameter rows produced by
                `_extract_predictions` / `_extract_targets`; column 0 is the
                existence flag.
            color: RGB tuple applied to every mesh.
            opacity: mesh opacity.
            mesh_resolution: sampling density passed to mesh generation.
            is_gt: selects the handle list and the "gt"/"gen" pool prefix.
        """
        prefix = "gt" if is_gt else "gen"
        num_active = 0

        for obj_idx, obj_params in enumerate(objects):
            # Existence flag > 0.5 means the slot holds a real object.
            if obj_params[0] > 0.5:
                num_active += 1

                try:
                    vertices, faces = self.generate_superquadric_mesh(
                        obj_params, num_samples=mesh_resolution
                    )

                    # Reuse pooled mesh handles keyed by prefix + slot index.
                    mesh_key = f"{prefix}_{obj_idx}"
                    mesh = self.get_or_create_mesh(
                        mesh_key, vertices, faces, color, opacity
                    )

                    if is_gt:
                        self.gt_superquadric_handles.append(mesh)
                    else:
                        self.superquadric_handles.append(mesh)

                except Exception as e:
                    # One bad object shouldn't abort the rest of the frame.
                    print(f"❌ 可视化对象{obj_idx}失败: {e}")

        label = "GT" if is_gt else "生成"
        print(f" {label}对象数: {num_active}")
|
|
|
|
    def _visualize_object_labels(self, frame_idx: int, objects: np.ndarray, is_gt: bool):
        """Show a floating info label above each existing GT object.

        Note: label data is read from `self.targets_npz` directly; the
        `objects` parameter is currently unused.
        """
        # Only implemented for GT data in the new per-frame dict format.
        if is_gt and self.targets_npz is not None and 'frames' in self.targets_npz:
            frames = self.targets_npz['frames']
            if frame_idx >= len(frames):
                return

            frame_data = frames[frame_idx]
            # npz stores each frame as a 0-d object array wrapping a dict.
            if isinstance(frame_data, np.ndarray):
                frame_data = frame_data.item()

            if 'superquadrics' not in frame_data:
                return

            superquadrics = frame_data['superquadrics']

            for obj_idx, sq in enumerate(superquadrics):
                if not sq['exists']:
                    continue

                translation = sq['translation']
                scale = sq['scale']

                # Hover the label above the object: 1.5x its Y half-extent.
                label_position = (
                    float(translation[0]),
                    float(translation[1]) + float(scale[1]) * 1.5,
                    float(translation[2])
                )

                # Missing keys fall back to neutral defaults.
                inlier_ratio = sq.get('inlier_ratio', 0.0)
                shape = sq.get('shape', [0, 0])

                info_text = (
                    f"ID: {obj_idx}\n"
                    f"Density: {inlier_ratio:.3f}\n"
                    f"Shape: ε1={shape[0]:.2f}, ε2={shape[1]:.2f}\n"
                    f"Size: {scale[0]:.2f}×{scale[1]:.2f}×{scale[2]:.2f}"
                )

                # Unique scene path per frame/object so handles don't collide.
                label_name = f"/object_label_f{frame_idx}_o{obj_idx}"
                try:
                    label_handle = self.server.scene.add_label(
                        label_name,
                        text=info_text,
                        position=label_position
                    )
                    self.object_label_handles.append(label_handle)
                except Exception as e:
                    print(f"⚠️ 创建标签失败: {e}")
|
|
|
|
    def _visualize_pointcloud(self, frame_idx: int, scene_center: Optional[np.ndarray] = None, scene_scale: Optional[float] = None):
        """Back-project the frame's depth map into a colored point cloud.

        Args:
            frame_idx: frame whose depth/RGB files are loaded.
            scene_center: optional normalization-center override.
            scene_scale: optional normalization-scale override.
        """
        if self.current_sample_path is None:
            return

        original_data_dir = self.current_sample_path / "original_data"
        if not original_data_dir.exists():
            print("⚠️ 未找到original_data目录")
            return

        # Both a depth frame and its RGB image are required.
        depth_file = self._find_depth_file(original_data_dir, frame_idx)
        rgb_file = original_data_dir / "rgb" / f"frame_{frame_idx:03d}.png"

        if depth_file is None or not rgb_file.exists():
            print(f"⚠️ 未找到帧{frame_idx}的深度图或RGB")
            return

        try:
            depth = self._load_depth(depth_file, frame_idx)
            if depth.ndim == 2:
                # Converter expects an explicit trailing channel axis.
                depth = depth[:, :, None]
            rgb = self._load_rgb(rgb_file)

            # Intrinsics from metadata.json; otherwise a crude default
            # derived from the image size.
            camera_K = None
            metadata_file = original_data_dir / "metadata.json"
            if metadata_file.exists():
                with open(metadata_file) as f:
                    metadata = json.load(f)
                if 'camera' in metadata and 'K' in metadata['camera']:
                    camera_K = np.array(metadata['camera']['K'], dtype=np.float32)
            if camera_K is None:
                h, w = depth.shape[:2]
                camera_K = np.array([[w, 0, w/2], [0, h, h/2], [0, 0, 1]], dtype=np.float32)

            # Extrinsics from GT world_info; identity pose as fallback.
            world_info = self._get_world_info(frame_idx, source="targets")
            camera_position = np.zeros(3, dtype=np.float32)
            camera_quat_xyzw = np.array([0, 0, 0, 1], dtype=np.float32)
            if world_info is not None and 'camera_position' in world_info:
                camera_position = np.array(world_info['camera_position'], dtype=np.float32)
                if 'camera_quaternion' in world_info:
                    # Stored in xyzw order.
                    camera_quat_xyzw = np.array(world_info['camera_quaternion'], dtype=np.float32)

            converter = DepthToPointCloud()
            _, points_norm, _, depth_center, depth_extent = converter.depth_to_normalized_pointcloud_movi(
                depth=depth,
                segmentation=None,
                camera_K=camera_K,
                camera_position=camera_position,
                camera_quaternion=camera_quat_xyzw,
                resolution=depth.shape[0],
                convert_to_zdepth=True,
                scene_center_override=scene_center,
                scene_scale_override=scene_scale
            )

            # Keep only pixels with valid (positive) depth.
            valid_mask = depth[:, :, 0] > 0
            points = points_norm[valid_mask]
            colors = rgb.reshape(-1, 3)[valid_mask.reshape(-1)]

            # Record the normalization actually used so other overlays
            # (cameras, meshes) stay aligned with the cloud.
            if scene_center is not None and scene_scale is not None:
                self.scene_center = np.array(scene_center, dtype=np.float32)
                self.scene_scale = float(scene_scale)
            else:
                self.scene_center = depth_center
                self.scene_scale = 20.0 / max(depth_extent, 1e-6)

            point_size = self.gui_controls['pointcloud_size'].value
            self.point_cloud_handle = self.server.scene.add_point_cloud(
                "/pointcloud",
                points=points,
                colors=colors,
                point_size=point_size
            )

            print(f" 点云: {len(points)} 个点")

        except Exception as e:
            # Best-effort: a missing/corrupt frame must not break the viewer.
            print(f"❌ 加载点云失败: {e}")
|
|
|
|
def _find_depth_file(self, original_data_dir: Path, frame_idx: int) -> Optional[Path]: |
|
|
"""查找深度文件(支持合并的npz和单独的npy)""" |
|
|
depth_dir = original_data_dir / "depth" |
|
|
if not depth_dir.exists(): |
|
|
return None |
|
|
|
|
|
|
|
|
merged_npz = depth_dir / "depth_merge.npz" |
|
|
if merged_npz.exists(): |
|
|
return merged_npz |
|
|
|
|
|
|
|
|
npy_file = depth_dir / f"frame_{frame_idx:03d}.npy" |
|
|
if npy_file.exists(): |
|
|
return npy_file |
|
|
|
|
|
return None |
|
|
|
|
|
def _load_depth(self, depth_file: Path, frame_idx: int) -> np.ndarray: |
|
|
"""加载深度数据""" |
|
|
if depth_file.suffix == '.npz': |
|
|
|
|
|
data = np.load(depth_file) |
|
|
frame_key = f"frame_{frame_idx:03d}" |
|
|
return data[frame_key] |
|
|
else: |
|
|
|
|
|
return np.load(depth_file) |
|
|
|
|
|
def _load_rgb(self, rgb_path: Path) -> np.ndarray: |
|
|
"""加载RGB图像""" |
|
|
img = cv2.imread(str(rgb_path)) |
|
|
if img is None: |
|
|
raise FileNotFoundError(f"Failed to load RGB image: {rgb_path}") |
|
|
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
def _get_world_info(self, frame_idx: int, source: str = "targets") -> Optional[Dict[str, np.ndarray]]: |
|
|
"""从pred/target获取世界/相机信息""" |
|
|
data = self.targets_npz if source == "targets" else self.predictions_npz |
|
|
if data is None: |
|
|
return None |
|
|
|
|
|
if 'frames' in data: |
|
|
frames = data['frames'] |
|
|
if frame_idx < len(frames): |
|
|
entry = frames[frame_idx] |
|
|
if hasattr(entry, 'item'): |
|
|
try: |
|
|
entry = entry.item() |
|
|
except Exception: |
|
|
pass |
|
|
if isinstance(entry, dict) and 'world_info' in entry: |
|
|
return entry['world_info'] |
|
|
|
|
|
|
|
|
if 'world' in data: |
|
|
world = data['world'] |
|
|
if hasattr(world, 'shape') and world.shape[0] > frame_idx and world.shape[-1] >= 7: |
|
|
wp = world[frame_idx] |
|
|
scene_center = world[frame_idx, 8:11] if world.shape[-1] >= 11 else np.zeros(3, dtype=np.float32) |
|
|
return { |
|
|
'camera_position': wp[:3], |
|
|
'camera_quaternion': wp[3:7], |
|
|
'scene_scale': float(wp[7]) if len(wp) > 7 else 1.0, |
|
|
'scene_center': scene_center, |
|
|
} |
|
|
return None |
|
|
|
|
|
    def _visualize_cameras(self, frame_idx: int):
        """Draw GT/predicted camera frusta, optionally textured with the
        frame's RGB image."""
        # Clear frusta left over from the previous frame.
        for h in self.camera_frustum_handles:
            h.remove()
        self.camera_frustum_handles = []
        if self.camera_rgb_handle is not None:
            self.camera_rgb_handle.remove()
            self.camera_rgb_handle = None

        # Bail out if the GUI isn't fully built or both frusta are hidden.
        show_target = self.gui_controls.get('show_target_frustum', None)
        show_pred = self.gui_controls.get('show_pred_frustum', None)
        show_rgb = self.gui_controls.get('show_camera_rgb', None)
        if show_target is None or show_pred is None or show_rgb is None:
            return
        if not (show_target.value or show_pred.value):
            return

        # Optionally load the frame's RGB image to texture the frusta.
        original_data_dir = None
        rgb_image = None
        if show_rgb.value and self.current_sample_path is not None:
            original_data_dir = self.current_sample_path / "original_data"
            if original_data_dir.exists():
                rgb_path = original_data_dir / "rgb" / f"frame_{frame_idx:03d}.png"
                if rgb_path.exists():
                    try:
                        rgb_image = self._load_rgb(rgb_path)
                    except Exception:
                        rgb_image = None

        # Default frustum geometry; refined from intrinsics when available.
        fov = np.deg2rad(60.0)
        aspect = 1.0
        if rgb_image is not None:
            h, w = rgb_image.shape[:2]
            aspect = w / max(h, 1)
            metadata_file = (self.current_sample_path / "original_data" / "metadata.json") if self.current_sample_path else None
            fx = None
            if metadata_file and metadata_file.exists():
                try:
                    with open(metadata_file) as f:
                        metadata = json.load(f)
                    if 'camera' in metadata and 'K' in metadata['camera']:
                        K = np.array(metadata['camera']['K'], dtype=np.float32)
                        fx = K[0, 0]
                except Exception:
                    fx = None
            if fx is not None and w > 0:
                # Horizontal FOV from the focal length.
                fov = 2 * np.arctan(w / (2 * fx))

        def add_frustum(world_info: Dict, name: str, color: Tuple[int, int, int]):
            # Add one frustum for the pose in `world_info` (xyzw quaternion).
            if world_info is None:
                return
            cam_pos = np.array(world_info.get('camera_position', np.zeros(3)), dtype=np.float32)
            cam_quat = np.array(world_info.get('camera_quaternion', [0, 0, 0, 1]), dtype=np.float32)
            if cam_quat.shape[0] == 4:
                # Convert stored xyzw to viser's wxyz order.
                wxyz = (float(cam_quat[3]), float(cam_quat[0]), float(cam_quat[1]), float(cam_quat[2]))
            else:
                wxyz = (1.0, 0.0, 0.0, 0.0)

            # Map world position into the normalized viewer space.
            pos = (cam_pos - self.scene_center) * getattr(self, "scene_scale", 1.0)

            frustum = self.server.scene.add_camera_frustum(
                f"/{name}",
                fov=fov,
                aspect=aspect,
                scale=2.0,
                wxyz=wxyz,
                position=pos,
                image=rgb_image if show_rgb.value else None,
                color=tuple(int(c) for c in color)
            )
            self.camera_frustum_handles.append(frustum)

        # Prediction frustum in blue, GT in red (matches the mesh colors).
        if show_pred.value:
            add_frustum(self._get_world_info(frame_idx, source="predictions"), "pred_camera_frustum", (100, 149, 237))
        if show_target.value:
            add_frustum(self._get_world_info(frame_idx, source="targets"), "gt_camera_frustum", (255, 99, 71))
|
|
|
def generate_superquadric_mesh(self, params, num_samples=25): |
|
|
"""生成超二次曲面mesh""" |
|
|
|
|
|
epsilon = [params[1], params[2]] |
|
|
scale = [params[3], params[4], params[5]] |
|
|
translation = [params[6], params[7], params[8]] |
|
|
rotation = [params[9], params[10], params[11]] if len(params) >= 12 else [0, 0, 0] |
|
|
|
|
|
|
|
|
eta = np.linspace(-np.pi/2, np.pi/2, num_samples) |
|
|
omega = np.linspace(-np.pi, np.pi, num_samples) |
|
|
|
|
|
vertices = [] |
|
|
faces = [] |
|
|
|
|
|
|
|
|
rot = Rotation.from_euler('ZYX', rotation) |
|
|
rot_matrix = rot.as_matrix() |
|
|
|
|
|
|
|
|
for i, e in enumerate(eta): |
|
|
for j, w in enumerate(omega): |
|
|
|
|
|
cos_eta = np.sign(np.cos(e)) * np.abs(np.cos(e))**epsilon[0] |
|
|
sin_eta = np.sign(np.sin(e)) * np.abs(np.sin(e))**epsilon[0] |
|
|
cos_omega = np.sign(np.cos(w)) * np.abs(np.cos(w))**epsilon[1] |
|
|
sin_omega = np.sign(np.sin(w)) * np.abs(np.sin(w))**epsilon[1] |
|
|
|
|
|
|
|
|
x_local = scale[0] * cos_eta * cos_omega |
|
|
y_local = scale[1] * cos_eta * sin_omega |
|
|
z_local = scale[2] * sin_eta |
|
|
|
|
|
|
|
|
point_local = np.array([x_local, y_local, z_local]) |
|
|
point_global = rot_matrix @ point_local + np.array(translation) |
|
|
|
|
|
vertices.append(point_global) |
|
|
|
|
|
vertices = np.array(vertices) |
|
|
|
|
|
|
|
|
for i in range(num_samples - 1): |
|
|
for j in range(num_samples - 1): |
|
|
idx1 = i * num_samples + j |
|
|
idx2 = i * num_samples + (j + 1) % num_samples |
|
|
idx3 = (i + 1) * num_samples + j |
|
|
idx4 = (i + 1) * num_samples + (j + 1) % num_samples |
|
|
|
|
|
faces.append([idx1, idx2, idx3]) |
|
|
faces.append([idx2, idx4, idx3]) |
|
|
|
|
|
return vertices, np.array(faces) |
|
|
|
|
|
def get_or_create_mesh(self, key: str, vertices, faces, color, opacity): |
|
|
"""获取或创建mesh(对象池)""" |
|
|
|
|
|
wireframe_mode = self.gui_controls.get('wireframe_mode', None) |
|
|
is_wireframe = wireframe_mode.value if wireframe_mode else False |
|
|
|
|
|
|
|
|
if is_wireframe: |
|
|
display_color = (255, 255, 255) |
|
|
display_opacity = 1.0 |
|
|
else: |
|
|
display_color = color |
|
|
display_opacity = opacity |
|
|
|
|
|
if key in self.mesh_handles_pool: |
|
|
mesh = self.mesh_handles_pool[key] |
|
|
mesh.vertices = vertices |
|
|
mesh.vertex_colors = None |
|
|
mesh.wireframe = is_wireframe |
|
|
mesh.opacity = display_opacity |
|
|
mesh.visible = True |
|
|
|
|
|
|
|
|
color_array = np.array(display_color, dtype=np.uint8) |
|
|
if color_array.max() <= 1.0: |
|
|
color_array = (color_array * 255).astype(np.uint8) |
|
|
mesh.color = tuple(color_array) |
|
|
else: |
|
|
|
|
|
color_array = np.array(display_color, dtype=np.uint8) |
|
|
if color_array.max() <= 1.0: |
|
|
color_array = (color_array * 255).astype(np.uint8) |
|
|
|
|
|
mesh = self.server.scene.add_mesh_simple( |
|
|
name=f"/mesh_{key}", |
|
|
vertices=vertices, |
|
|
faces=faces, |
|
|
color=tuple(color_array), |
|
|
opacity=display_opacity, |
|
|
wireframe=is_wireframe, |
|
|
flat_shading=False |
|
|
) |
|
|
self.mesh_handles_pool[key] = mesh |
|
|
|
|
|
return mesh |
|
|
|
|
|
def clear_visualization(self): |
|
|
"""清空可视化""" |
|
|
|
|
|
for mesh in self.mesh_handles_pool.values(): |
|
|
mesh.visible = False |
|
|
|
|
|
|
|
|
self.superquadric_handles = [] |
|
|
self.gt_superquadric_handles = [] |
|
|
|
|
|
|
|
|
if self.point_cloud_handle is not None: |
|
|
self.point_cloud_handle.remove() |
|
|
self.point_cloud_handle = None |
|
|
|
|
|
|
|
|
for handle in self.camera_frustum_handles: |
|
|
handle.remove() |
|
|
self.camera_frustum_handles = [] |
|
|
if self.camera_rgb_handle is not None: |
|
|
self.camera_rgb_handle.remove() |
|
|
self.camera_rgb_handle = None |
|
|
|
|
|
|
|
|
if self.coordinate_frame_handle is not None: |
|
|
self.coordinate_frame_handle.remove() |
|
|
self.coordinate_frame_handle = None |
|
|
|
|
|
|
|
|
for handle in self.object_label_handles: |
|
|
try: |
|
|
handle.remove() |
|
|
except (KeyError, AttributeError): |
|
|
|
|
|
pass |
|
|
self.object_label_handles = [] |
|
|
|
|
|
def _on_capture_camera(self, event): |
|
|
"""捕获当前相机视角""" |
|
|
clients = list(self.server.get_clients().values()) |
|
|
if not clients: |
|
|
print("⚠️ 没有连接的客户端") |
|
|
self.gui_controls['export_status'].value = "错误: 没有连接的客户端" |
|
|
return |
|
|
|
|
|
|
|
|
client = clients[0] |
|
|
self.export_camera_pos = np.array(client.camera.position) |
|
|
self.export_camera_wxyz = np.array(client.camera.wxyz) |
|
|
|
|
|
print(f"📸 已捕获相机视角: pos={self.export_camera_pos}, wxyz={self.export_camera_wxyz}") |
|
|
self.gui_controls['export_status'].value = f"已捕获视角: {self.export_camera_pos}" |
|
|
|
|
|
def _on_export_viser(self, event): |
|
|
"""导出为viser场景文件(可交互)""" |
|
|
if self.current_sample_path is None: |
|
|
print("⚠️ 请先加载样本") |
|
|
self.gui_controls['export_status'].value = "错误: 请先加载样本" |
|
|
return |
|
|
|
|
|
if self.original_frame_count <= 0: |
|
|
print("⚠️ 没有帧可以导出") |
|
|
self.gui_controls['export_status'].value = "错误: 没有帧可以导出" |
|
|
return |
|
|
|
|
|
|
|
|
threading.Thread(target=self._export_viser_thread, daemon=True).start() |
|
|
|
|
|
    def _export_viser_thread(self):
        """Worker thread: serialize every frame into a playable .viser scene file.

        Re-renders each frame through ``visualize_frame`` while a scene
        serializer records the updates, interleaving per-frame sleeps so the
        recording plays back at the GUI-selected FPS. Writes the result under
        ``core_space/exports/`` and, if a browser client is connected, pushes
        the file to it as a download. All failures are reported via the
        export-status GUI field.
        """
        try:
            print(f"\n{'='*60}")
            print(f"💾 开始导出Viser场景")
            print(f"{'='*60}")

            # Capture the first connected client's viewpoint so the playback
            # URL can restore the same camera via query parameters.
            clients = list(self.server.get_clients().values())
            camera_params = None
            if clients:
                client = clients[0]
                cam_pos = client.camera.position
                cam_lookat = client.camera.look_at
                cam_up = client.camera.up_direction
                camera_params = (
                    f"&initialCameraPosition={cam_pos[0]:.3f},{cam_pos[1]:.3f},{cam_pos[2]:.3f}"
                    f"&initialCameraLookAt={cam_lookat[0]:.3f},{cam_lookat[1]:.3f},{cam_lookat[2]:.3f}"
                    f"&initialCameraUp={cam_up[0]:.3f},{cam_up[1]:.3f},{cam_up[2]:.3f}"
                )
                print(f" 📸 记录相机视角:")
                print(f" 位置: {cam_pos}")
                print(f" 朝向: {cam_lookat}")
                print(f" 向上: {cam_up}")

            # Playback speed comes from the GUI slider.
            fps = int(self.gui_controls['fps_slider'].value)

            # Output location: core_space/exports/.
            output_dir = self.core_space_dir / "exports"
            output_dir.mkdir(exist_ok=True)

            # Build a descriptive filename from the selected experiment output,
            # its training step (parsed best-effort from the name), the sample
            # index, and a timestamp.
            selected_output = self.gui_controls['output_selector'].value
            sample_idx = int(self.gui_controls['sample_slider'].value)
            step_info = "unknown"
            if "step" in selected_output:
                try:
                    step_part = selected_output.split("_step")[1].split("_")[0]
                    step_info = f"step{step_part}"
                except:
                    pass

            timestamp = time.strftime("%Y%m%d_%H%M%S")
            experiment_name = selected_output.split("_")[0]
            output_file = output_dir / f"{experiment_name}_{step_info}_sample{sample_idx}_{timestamp}.viser"

            print(f" 输出文件: {output_file}")
            print(f" 帧数: {self.original_frame_count}")
            print(f" FPS: {fps}")

            # The serializer records all scene mutations from this point on.
            serializer = self.server.get_scene_serializer()

            # Record frame 0, then each remaining frame, inserting a sleep of
            # one frame-period between them so playback runs at `fps`.
            self.visualize_frame(0)
            serializer.insert_sleep(1.0 / fps)

            for frame_idx in range(1, self.original_frame_count):
                self.export_progress = int((frame_idx + 1) / self.original_frame_count * 100)
                self.gui_controls['export_status'].value = f"导出中... {self.export_progress}%"

                self.visualize_frame(frame_idx)

                serializer.insert_sleep(1.0 / fps)

                print(f" 记录帧 {frame_idx+1}/{self.original_frame_count}")

            # Flush the recording to disk.
            data = serializer.serialize()
            output_file.write_bytes(data)

            print(f"✅ 场景导出完成: {output_file}")
            print(f" 文件大小: {len(data) / 1024 / 1024:.2f} MB")
            print(f"\n📖 查看方式:")
            print(f" 1. 安装viser客户端: viser-build-client --output-dir viser-client/")
            print(f" 2. 启动HTTP服务器: python -m http.server 8000")

            # Print a ready-to-use playback URL (with the captured camera
            # viewpoint appended when one was recorded).
            base_url = f"http://localhost:8000/viser-client/?playbackPath=http://localhost:8000/exports/{output_file.name}"
            if camera_params:
                full_url = base_url + camera_params
                print(f" 3. 打开浏览器(带相机视角):")
                print(f" {full_url}")
            else:
                print(f" 3. 打开浏览器:")
                print(f" {base_url}")

            relative_path = output_file.relative_to(self.core_space_dir)
            self.gui_controls['export_status'].value = f"完成! {relative_path}"

            # Push the file to the browser as a download (if still connected).
            clients = list(self.server.get_clients().values())
            if clients:
                clients[0].send_file_download(output_file.name, data)
                print(f" 💾 已发送下载到浏览器")

        except Exception as e:
            print(f"❌ 导出失败: {e}")
            import traceback
            traceback.print_exc()
            self.gui_controls['export_status'].value = f"错误: {str(e)}"
|
|
|
|
|
def _on_export_video(self, event): |
|
|
"""导出视频""" |
|
|
if self.is_exporting: |
|
|
print("⚠️ 正在导出中,请等待...") |
|
|
return |
|
|
|
|
|
if self.current_sample_path is None: |
|
|
print("⚠️ 请先加载样本") |
|
|
self.gui_controls['export_status'].value = "错误: 请先加载样本" |
|
|
return |
|
|
|
|
|
if self.original_frame_count <= 0: |
|
|
print("⚠️ 没有帧可以导出") |
|
|
self.gui_controls['export_status'].value = "错误: 没有帧可以导出" |
|
|
return |
|
|
|
|
|
|
|
|
clients = list(self.server.get_clients().values()) |
|
|
if not clients: |
|
|
print("⚠️ 没有连接的客户端") |
|
|
self.gui_controls['export_status'].value = "错误: 请先在浏览器中打开viser界面" |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
client = clients[0] |
|
|
self.export_camera_pos = np.array(client.camera.position) |
|
|
self.export_camera_wxyz = np.array(client.camera.wxyz) |
|
|
print(f"📸 使用当前视角: pos={self.export_camera_pos}, wxyz={self.export_camera_wxyz}") |
|
|
|
|
|
|
|
|
threading.Thread(target=self._export_video_thread_screenshot, daemon=True).start() |
|
|
|
|
|
    def _export_video_thread(self):
        """Worker thread: render every frame offline and encode them into an MP4.

        Uses ``_render_frame_offline`` (pyrender, with simple-render fallback)
        for each frame from the previously captured camera viewpoint, and
        encodes with imageio/FFmpeg (H.264) when available, otherwise OpenCV.
        Progress and errors are reported via the export-status GUI field; the
        ``is_exporting`` flag is always cleared on exit.
        """
        try:
            self.is_exporting = True
            self.gui_controls['export_status'].value = "正在导出..."

            # Make sure scene normalization (center/scale) has been computed;
            # rendering one frame populates it as a side effect.
            if not hasattr(self, 'scene_center') or self.scene_center is None:
                print(" 初始化场景参数...")
                self.visualize_frame(self.current_frame)

            # Encoding parameters from the GUI.
            fps = int(self.gui_controls['fps_slider'].value)
            resolution = int(self.gui_controls['export_resolution'].value)

            # Output location: core_space/exports/.
            output_dir = self.core_space_dir / "exports"
            output_dir.mkdir(exist_ok=True)

            selected_output = self.gui_controls['output_selector'].value
            sample_idx = int(self.gui_controls['sample_slider'].value)

            # Parse the training step out of the output name (best effort).
            step_info = "unknown"
            if "step" in selected_output:
                try:
                    step_part = selected_output.split("_step")[1].split("_")[0]
                    step_info = f"step{step_part}"
                except:
                    pass

            # Descriptive, timestamped output filename.
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            experiment_name = selected_output.split("_")[0]
            output_file = output_dir / f"{experiment_name}_{step_info}_sample{sample_idx}_{timestamp}.mp4"

            print(f"\n{'='*60}")
            print(f"🎬 开始导出视频")
            print(f"{'='*60}")
            print(f" 实验: {selected_output}")
            print(f" 样本: {sample_idx}")
            print(f" 输出文件: {output_file}")
            print(f" 帧数: {self.original_frame_count}")
            print(f" FPS: {fps}")
            print(f" 分辨率: {resolution}x{resolution}")
            print(f" 相机位置: {self.export_camera_pos}")
            print(f" 相机旋转: {self.export_camera_wxyz}")

            # Prefer imageio (FFmpeg/H.264, widely playable); fall back to cv2.
            try:
                import imageio
                use_imageio = True
                print(" 使用 imageio 进行视频编码(H.264)")
            except ImportError:
                use_imageio = False
                print(" 使用 OpenCV 进行视频编码")

            if use_imageio:
                # H.264 at CRF 18 (visually near-lossless), yuv420p for
                # broad player compatibility.
                writer = imageio.get_writer(
                    str(output_file),
                    format='FFMPEG',
                    mode='I',
                    fps=fps,
                    codec='libx264',
                    pixelformat='yuv420p',
                    output_params=['-crf', '18']
                )

                # Render and append each frame (imageio expects RGB).
                for frame_idx in range(self.original_frame_count):
                    self.export_progress = int((frame_idx + 1) / self.original_frame_count * 100)
                    self.gui_controls['export_status'].value = f"导出中... {self.export_progress}%"

                    frame_image = self._render_frame_offline(
                        frame_idx,
                        resolution=resolution,
                        camera_pos=self.export_camera_pos,
                        camera_wxyz=self.export_camera_wxyz
                    )

                    if frame_image is not None:
                        writer.append_data(frame_image)

                    print(f" 渲染帧 {frame_idx+1}/{self.original_frame_count}")

                writer.close()

            else:
                # OpenCV path: probe codecs in order of preference until one
                # actually opens a writer on this platform.
                codecs_to_try = [
                    ('H264', 'H.264'),
                    ('avc1', 'H.264 (AVC1)'),
                    ('X264', 'X264'),
                    ('mp4v', 'MPEG-4')
                ]

                writer = None
                used_codec = None

                for codec_fourcc, codec_name in codecs_to_try:
                    try:
                        fourcc = cv2.VideoWriter_fourcc(*codec_fourcc)
                        test_writer = cv2.VideoWriter(
                            str(output_file),
                            fourcc,
                            fps,
                            (resolution, resolution)
                        )
                        if test_writer.isOpened():
                            writer = test_writer
                            used_codec = codec_name
                            print(f" 使用编码器: {codec_name}")
                            break
                        else:
                            test_writer.release()
                    except:
                        continue

                if writer is None:
                    raise RuntimeError("无法初始化视频编码器")

                # Render and write each frame (cv2 expects BGR, so convert).
                for frame_idx in range(self.original_frame_count):
                    self.export_progress = int((frame_idx + 1) / self.original_frame_count * 100)
                    self.gui_controls['export_status'].value = f"导出中... {self.export_progress}%"

                    frame_image = self._render_frame_offline(
                        frame_idx,
                        resolution=resolution,
                        camera_pos=self.export_camera_pos,
                        camera_wxyz=self.export_camera_wxyz
                    )

                    if frame_image is not None:
                        writer.write(cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR))

                    print(f" 渲染帧 {frame_idx+1}/{self.original_frame_count}")

                writer.release()

            print(f"✅ 视频导出完成: {output_file}")
            relative_path = output_file.relative_to(self.core_space_dir)
            self.gui_controls['export_status'].value = f"完成! {relative_path}"

        except Exception as e:
            print(f"❌ 导出视频失败: {e}")
            import traceback
            traceback.print_exc()
            self.gui_controls['export_status'].value = f"错误: {str(e)}"

        finally:
            self.is_exporting = False
|
|
|
|
|
    def _export_video_thread_screenshot(self):
        """Worker thread: export an MP4 by screenshotting the live viser UI.

        Drives the frame slider while a headless Chrome (via Selenium) takes a
        screenshot of the running web UI per frame, then encodes the captured
        frames (imageio/H.264 preferred, cv2/mp4v fallback). If Selenium is
        unavailable or fails, falls back to the offline renderer in
        ``_export_video_thread``. ``is_exporting`` is always cleared on exit.
        """
        try:
            self.is_exporting = True
            self.gui_controls['export_status'].value = "正在导出..."

            # Encoding FPS from the GUI.
            fps = int(self.gui_controls['fps_slider'].value)

            # Output location: core_space/exports/.
            output_dir = self.core_space_dir / "exports"
            output_dir.mkdir(exist_ok=True)

            # Descriptive, timestamped output filename (step parsed best-effort).
            selected_output = self.gui_controls['output_selector'].value
            sample_idx = int(self.gui_controls['sample_slider'].value)
            step_info = "unknown"
            if "step" in selected_output:
                try:
                    step_part = selected_output.split("_step")[1].split("_")[0]
                    step_info = f"step{step_part}"
                except:
                    pass

            timestamp = time.strftime("%Y%m%d_%H%M%S")
            experiment_name = selected_output.split("_")[0]
            output_file = output_dir / f"{experiment_name}_{step_info}_sample{sample_idx}_{timestamp}.mp4"

            print(f"\n{'='*60}")
            print(f"🎬 开始导出视频(截图模式)")
            print(f"{'='*60}")
            print(f" 实验: {selected_output}")
            print(f" 样本: {sample_idx}")
            print(f" 输出文件: {output_file}")
            print(f" 帧数: {self.original_frame_count}")
            print(f" FPS: {fps}")
            print(f" 方法: 直接截取Viser显示画面")

            # Selenium is optional; without it we fall through to offline rendering.
            try:
                from selenium import webdriver
                from selenium.webdriver.chrome.options import Options
                from selenium.webdriver.common.by import By
                import time as time_module
                use_selenium = True
                print(" ✅ 使用 Selenium 截图")
            except ImportError:
                print(" ⚠️ Selenium未安装,使用逐帧渲染方法")
                print(" 提示: pip install selenium")
                use_selenium = False

            if use_selenium:
                # Screenshots accumulated in memory (BGR, as decoded by cv2).
                frames = []

                # Headless Chrome pointed at this server's own web UI.
                chrome_options = Options()
                chrome_options.add_argument('--headless')
                chrome_options.add_argument('--no-sandbox')
                chrome_options.add_argument('--disable-dev-shm-usage')
                chrome_options.add_argument('--window-size=1920,1080')

                try:
                    driver = webdriver.Chrome(options=chrome_options)
                    url = f"http://localhost:{self.port}"
                    driver.get(url)
                    print(f" 📱 打开浏览器: {url}")

                    # Give the page time to load the 3D scene.
                    time_module.sleep(3)

                    for frame_idx in range(self.original_frame_count):
                        self.export_progress = int((frame_idx + 1) / self.original_frame_count * 100)
                        self.gui_controls['export_status'].value = f"截图中... {self.export_progress}%"

                        # Move the GUI's frame slider, wait for the scene to
                        # update, then grab a screenshot.
                        self.gui_controls['frame_slider'].value = frame_idx
                        time_module.sleep(0.3)

                        screenshot = driver.get_screenshot_as_png()
                        img = cv2.imdecode(np.frombuffer(screenshot, np.uint8), cv2.IMREAD_COLOR)
                        frames.append(img)

                        print(f" 截图帧 {frame_idx+1}/{self.original_frame_count}")

                    driver.quit()

                    # Encode: imageio/H.264 preferred for player compatibility.
                    try:
                        import imageio
                        writer = imageio.get_writer(
                            str(output_file),
                            format='FFMPEG',
                            mode='I',
                            fps=fps,
                            codec='libx264',
                            pixelformat='yuv420p',
                            output_params=['-crf', '18']
                        )

                        for frame in frames:
                            # cv2 decodes BGR; imageio expects RGB.
                            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            writer.append_data(frame_rgb)

                        writer.close()
                        print(f"✅ 视频导出完成: {output_file}")
                        relative_path = output_file.relative_to(self.core_space_dir)
                        self.gui_controls['export_status'].value = f"完成! {relative_path}"

                    except ImportError:
                        # cv2 fallback encoder (mp4v); frames are already BGR.
                        height, width = frames[0].shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        writer = cv2.VideoWriter(str(output_file), fourcc, fps, (width, height))
                        for frame in frames:
                            writer.write(frame)
                        writer.release()
                        print(f"✅ 视频导出完成: {output_file}")
                        relative_path = output_file.relative_to(self.core_space_dir)
                        self.gui_controls['export_status'].value = f"完成! {relative_path}"

                except Exception as e:
                    print(f"❌ Selenium截图失败: {e}")
                    import traceback
                    traceback.print_exc()
                    # Fall through to the offline-render path below.
                    use_selenium = False

            if not use_selenium:
                # Offline PyRender fallback (shares the rest of the pipeline).
                print(" 使用PyRender离线渲染...")
                self._export_video_thread()
                return

        except Exception as e:
            print(f"❌ 导出视频失败: {e}")
            import traceback
            traceback.print_exc()
            self.gui_controls['export_status'].value = f"错误: {str(e)}"

        finally:
            self.is_exporting = False
|
|
|
|
|
    def _render_frame_offline(self, frame_idx: int, resolution: int,
                              camera_pos: np.ndarray, camera_wxyz: np.ndarray) -> Optional[np.ndarray]:
        """Render one frame offscreen with pyrender from a fixed viewpoint.

        Rebuilds the superquadric scene (predicted and/or GT objects, per the
        GUI toggles) as trimesh/pyrender meshes, places a perspective camera
        at the captured browser viewpoint, and renders offscreen, trying the
        EGL backend first and then OSMesa. Falls back to
        ``_render_frame_simple`` when pyrender is missing or both backends fail.

        Args:
            frame_idx: Frame to render (also gates one-time debug printing).
            resolution: Output image is resolution x resolution pixels.
            camera_pos: Camera position in (normalized) scene coordinates.
            camera_wxyz: Camera orientation quaternion in (w, x, y, z) order.

        Returns:
            (resolution, resolution, 3) RGB image, or the simple-render fallback.
        """
        # pyrender/trimesh are optional dependencies.
        try:
            import pyrender
            import trimesh
        except ImportError:
            if frame_idx == 0:
                print("⚠️ pyrender未安装,使用简化渲染...")
                print(" 提示: 安装 pyrender 以获得完整3D渲染")
                print(" pip install pyrender trimesh")
            return self._render_frame_simple(frame_idx, resolution)

        # Try headless GL backends in order; any failure with 'egl' retries
        # with 'osmesa', whose failure falls back to simple rendering.
        for platform in ['egl', 'osmesa']:
            try:
                # NOTE(review): PYOPENGL_PLATFORM is normally read when the GL
                # bindings are first imported; setting it here, after pyrender
                # is imported, may have no effect — verify on the target setup.
                os.environ['PYOPENGL_PLATFORM'] = platform

                # Scene with a dark-navy background matching the viser UI.
                scene = pyrender.Scene(
                    ambient_light=[0.3, 0.3, 0.3],
                    bg_color=[13/255, 13/255, 38/255, 1.0]
                )

                # Display options from the GUI (colors come in as 0-255 RGB).
                show_generated = self.gui_controls['show_generated'].value
                show_gt = self.gui_controls['show_gt'].value
                generated_color = np.array(self.gui_controls['generated_color'].value) / 255.0
                gt_color = np.array(self.gui_controls['gt_color'].value) / 255.0
                mesh_resolution = int(self.gui_controls['mesh_resolution'].value)

                mesh_count = 0

                # --- Predicted objects ---------------------------------------
                if show_generated:
                    predictions = self._extract_predictions(frame_idx)
                    if predictions is not None:
                        for obj_idx, obj_params in enumerate(predictions):
                            # params[0] is an occupancy/confidence score;
                            # only render objects above 0.5.
                            if obj_params[0] > 0.5:
                                # Apply the same center/scale normalization used
                                # by the live viser scene to translation (6:9)
                                # and per-axis scale (3:6).
                                obj_params_normalized = obj_params.copy()

                                translation = obj_params[6:9]
                                translation_normalized = (translation - self.scene_center) * self.scene_scale
                                obj_params_normalized[6:9] = translation_normalized

                                obj_params_normalized[3:6] = obj_params[3:6] * self.scene_scale

                                vertices, faces = self.generate_superquadric_mesh(
                                    obj_params_normalized, num_samples=mesh_resolution
                                )

                                # One-time debug dump for the first object.
                                if frame_idx == 0 and obj_idx == 0:
                                    print(f" 物体原始位置: {translation}")
                                    print(f" 物体归一化位置: {translation_normalized}")
                                    print(f" 场景中心: {self.scene_center}, 缩放: {self.scene_scale}")

                                mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

                                # Opaque per-vertex colors in the generated color.
                                num_verts = len(vertices)
                                vertex_colors = np.zeros((num_verts, 4), dtype=np.uint8)
                                vertex_colors[:, :3] = (generated_color * 255).astype(np.uint8)
                                vertex_colors[:, 3] = 255
                                mesh.visual.vertex_colors = vertex_colors

                                # Slightly metallic, mostly rough material.
                                material = pyrender.MetallicRoughnessMaterial(
                                    baseColorFactor=list(generated_color) + [1.0],
                                    metallicFactor=0.3,
                                    roughnessFactor=0.7
                                )
                                mesh_obj = pyrender.Mesh.from_trimesh(mesh, material=material)
                                scene.add(mesh_obj)
                                mesh_count += 1

                # --- Ground-truth objects (semi-transparent base color) ------
                if show_gt:
                    targets = self._extract_targets(frame_idx)
                    if targets is not None:
                        for obj_idx, obj_params in enumerate(targets):
                            if obj_params[0] > 0.5:
                                # Same normalization as the predicted branch.
                                obj_params_normalized = obj_params.copy()
                                translation = obj_params[6:9]
                                translation_normalized = (translation - self.scene_center) * self.scene_scale
                                obj_params_normalized[6:9] = translation_normalized
                                obj_params_normalized[3:6] = obj_params[3:6] * self.scene_scale

                                vertices, faces = self.generate_superquadric_mesh(
                                    obj_params_normalized, num_samples=mesh_resolution
                                )
                                mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

                                num_verts = len(vertices)
                                vertex_colors = np.zeros((num_verts, 4), dtype=np.uint8)
                                vertex_colors[:, :3] = (gt_color * 255).astype(np.uint8)
                                vertex_colors[:, 3] = 255
                                mesh.visual.vertex_colors = vertex_colors

                                # GT uses a 0.5-alpha base color to distinguish
                                # it from predictions.
                                material = pyrender.MetallicRoughnessMaterial(
                                    baseColorFactor=list(gt_color) + [0.5],
                                    metallicFactor=0.3,
                                    roughnessFactor=0.7
                                )
                                mesh_obj = pyrender.Mesh.from_trimesh(mesh, material=material)
                                scene.add(mesh_obj)
                                mesh_count += 1

                if frame_idx == 0:
                    print(f" 场景中添加了 {mesh_count} 个mesh")

                # --- Camera ---------------------------------------------------
                from scipy.spatial.transform import Rotation as R

                # Input quaternion is (w, x, y, z); scipy expects (x, y, z, w).
                rot = R.from_quat([camera_wxyz[1], camera_wxyz[2], camera_wxyz[3], camera_wxyz[0]])
                rot_matrix = rot.as_matrix()

                camera_pose = np.eye(4)
                camera_pose[:3, :3] = rot_matrix
                camera_pose[:3, 3] = camera_pos

                if frame_idx == 0:
                    print(f" 相机位置: {camera_pos}")
                    print(f" 相机旋转矩阵:\n{rot_matrix}")

                # 60-degree vertical FOV, square aspect to match the output.
                camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)
                scene.add(camera, pose=camera_pose)

                # --- Lights ---------------------------------------------------
                # Headlight attached to the camera pose...
                light1 = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
                scene.add(light1, pose=camera_pose)

                # ...plus a weaker fixed fill light from (10, 10, 10).
                light2 = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1.0)
                light_pose = np.eye(4)
                light_pose[:3, 3] = [10, 10, 10]
                scene.add(light2, pose=light_pose)

                # --- Offscreen render -----------------------------------------
                renderer = pyrender.OffscreenRenderer(resolution, resolution)
                color, depth = renderer.render(scene)
                renderer.delete()

                if frame_idx == 0:
                    print(f" ✅ 使用 {platform.upper()} 进行离线渲染")
                    print(f" 渲染输出范围: [{color.min()}, {color.max()}]")
                    print(f" 深度范围: [{depth.min()}, {depth.max()}]")

                return color

            except Exception as e:
                if platform == 'osmesa':
                    # Both backends failed: give up on pyrender for this frame.
                    if frame_idx == 0:
                        print(f"❌ PyRender渲染失败 (EGL和OSMesa都不可用): {e}")
                        print(" 使用简化渲染模式...")
                    return self._render_frame_simple(frame_idx, resolution)

                # EGL failed; retry the loop with OSMesa.
                continue

        # Safety net (normally unreachable: the loop returns on every path).
        return self._render_frame_simple(frame_idx, resolution)
|
|
|
|
|
def _render_frame_simple(self, frame_idx: int, resolution: int) -> np.ndarray: |
|
|
"""简化渲染(纯色背景 + 文字提示)""" |
|
|
|
|
|
image = np.full((resolution, resolution, 3), [13, 13, 38], dtype=np.uint8) |
|
|
|
|
|
|
|
|
text = f"Frame {frame_idx + 1}/{self.original_frame_count}" |
|
|
font = cv2.FONT_HERSHEY_SIMPLEX |
|
|
text_size = cv2.getTextSize(text, font, 1, 2)[0] |
|
|
text_x = (resolution - text_size[0]) // 2 |
|
|
text_y = (resolution + text_size[1]) // 2 |
|
|
|
|
|
cv2.putText(image, text, (text_x, text_y), font, 1, (255, 255, 255), 2) |
|
|
|
|
|
|
|
|
hint = "Install pyrender for full rendering" |
|
|
hint_size = cv2.getTextSize(hint, font, 0.5, 1)[0] |
|
|
hint_x = (resolution - hint_size[0]) // 2 |
|
|
hint_y = text_y + 40 |
|
|
|
|
|
cv2.putText(image, hint, (hint_x, hint_y), font, 0.5, (150, 150, 150), 1) |
|
|
|
|
|
return image |
|
|
|
|
|
def run(self, auto_open_browser: bool = True): |
|
|
"""运行可视化器""" |
|
|
print("\n" + "="*60) |
|
|
print("🎨 WaveGen 训练可视化器") |
|
|
print("="*60) |
|
|
print(f"📁 监控目录: {self.core_space_dir}") |
|
|
print(f"🌐 Web界面: http://localhost:{self.port}") |
|
|
print("="*60) |
|
|
print("\n💡 提示:") |
|
|
print(" - 如果页面空白一直加载,请刷新浏览器 (Ctrl+Shift+R)") |
|
|
print(" - 建议使用 Chrome 或 Firefox 浏览器") |
|
|
print("\n按 Ctrl+C 退出\n") |
|
|
|
|
|
|
|
|
if auto_open_browser: |
|
|
url = f"http://localhost:{self.port}" |
|
|
print(f"🌐 正在打开浏览器: {url}") |
|
|
try: |
|
|
webbrowser.open(url) |
|
|
except Exception as e: |
|
|
print(f"⚠️ 无法自动打开浏览器: {e}") |
|
|
print(f" 请手动访问: {url}") |
|
|
|
|
|
try: |
|
|
while True: |
|
|
time.sleep(0.1) |
|
|
except KeyboardInterrupt: |
|
|
print("\n👋 再见!") |
|
|
print("正在关闭服务器...") |
|
|
|
|
|
try: |
|
|
for mesh in self.mesh_handles_pool.values(): |
|
|
mesh.remove() |
|
|
except: |
|
|
pass |
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse command-line arguments and launch the visualizer."""
    import argparse

    parser = argparse.ArgumentParser(description="WaveGen训练结果可视化工具")
    parser.add_argument(
        '--core-space',
        type=str,
        default='core_space',
        help='core_space目录路径(默认: ./core_space)',
    )
    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        help='Viser服务器端口(默认: 8080,如果被占用会自动尝试下一个端口)',
    )
    parser.add_argument(
        '--no-browser',
        action='store_true',
        help='不自动打开浏览器',
    )
    args = parser.parse_args()

    # Build the visualizer and block until Ctrl+C.
    viewer = TrainingVisualizer(core_space_dir=args.core_space, port=args.port)
    viewer.run(auto_open_browser=not args.no_browser)
|
|
|
|
|
|
|
|
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|