""" 保存生成结果的工具函数 用于WaveGen v33 - 使用超二次元函数的版本 """ import os import numpy as np import torch from pathlib import Path from typing import Dict, List, Optional import json from datetime import datetime import shutil def save_generation_results( predictions: List[Dict], targets: Dict[str, torch.Tensor], texts: List[str], step: int, output_dir: str = "outputs", save_config: Dict = None, metadata: Dict = None, batch_data: Dict = None, # 新增:完整的批次数据 data_root: str = None, # 新增:原始数据根目录 data_split: str = "validation" # 新增:数据集split(train/validation) ): """ 保存生成结果用于可视化和分析(增强版) Args: predictions: 模型预测结果列表,每个元素包含 'objects', 'world', 'physics' targets: 真实目标数据 texts: 输入文本描述 step: 当前训练步数 output_dir: 输出目录 save_config: 保存配置 metadata: 额外的元数据(如序列名称、相机参数等) batch_data: 完整的批次数据,用于获取更多原始信息 data_root: 原始数据根目录,用于复制原始文件 data_split: 数据集split('train' 或 'validation'),用于确定原始数据的正确位置 """ # 创建时间戳目录 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") save_path = Path(output_dir) / f"{timestamp}_step{step}_text2wave" save_path.mkdir(parents=True, exist_ok=True) # 保存配置 if save_config is not None: with open(save_path / "save_config.json", 'w') as f: json.dump(save_config, f, indent=2) batch_size = len(texts) num_frames = len(predictions) print(f"\n{'='*60}") print(f"💾 保存生成结果: step {step}") print(f"📁 保存路径: {save_path}") print(f"📊 样本数: {batch_size}, 帧数: {num_frames}") print(f"{'='*60}\n") for i in range(batch_size): try: sample_dir = save_path / f"sample_{i}" sample_dir.mkdir(exist_ok=True) print(f"正在保存 sample_{i}... ", end='', flush=True) # 1. 保存文本描述和元信息 with open(sample_dir / "info.txt", 'w') as f: f.write(f"Text: {texts[i]}\n") f.write(f"Generated at step: {step}\n") f.write(f"Number of frames: {num_frames}\n") if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None: f.write(f"Sequence: {metadata['sequence_names'][i]}\n") f.write("\n--- Model Output Summary ---\n") f.write(f"Max objects: {predictions[0]['objects'].shape[1]}\n") f.write("Object parameters: 15 (exists + shape[2] + scale[3] + translation[3] + rotation[3] + velocity[3])\n") f.write(f"World parameters: 8 (camera_pos[3] + camera_quat[4] + scene_scale[1])\n") f.write(f"Physics parameters: 3 (mass + friction + restitution)\n") # 2. 
            # 2. Save the predicted superquadric parameters (improved format)
            frame_predictions = []
            physics_per_frame = []  # Per-frame physics predictions, for inspecting the full sequence

            for f_idx, pred in enumerate(predictions):
                # Each element of `predictions` holds the full batch for one frame;
                # extract the current sample first.
                objects_batch = pred['objects']
                world_batch = pred['world']
                physics_batch = pred.get('physics')

                objects_params = objects_batch[i].cpu().numpy() if hasattr(objects_batch, 'cpu') else objects_batch[i]  # [max_objects, 15]
                world_params = world_batch[i].cpu().numpy() if hasattr(world_batch, 'cpu') else world_batch[i]  # [8]
                if physics_batch is not None:
                    physics_params = physics_batch[i].cpu().numpy() if hasattr(physics_batch, 'cpu') else physics_batch[i]  # [max_objects, 3]
                else:
                    physics_params = np.zeros((objects_params.shape[0], 3), dtype=np.float32)

                # Save the current frame's physics (mass/friction/restitution per object)
                physics_per_frame.append([
                    {
                        'mass': float(phys_params[0]),
                        'friction': float(phys_params[1]),
                        'restitution': float(phys_params[2]),
                    }
                    for phys_params in physics_params
                ])

                # Convert the object parameters into a more readable format
                superquadrics = []
                for obj_idx in range(objects_params.shape[0]):
                    obj_params = objects_params[obj_idx]
                    phys_params = physics_params[obj_idx]
                    superquadric = {
                        'exists': bool(obj_params[0] > 0.5),  # exists flag
                        'shape': obj_params[1:3],             # epsilon1, epsilon2
                        'scale': obj_params[3:6],             # a, b, c
                        'translation': obj_params[6:9],       # x, y, z
                        'rotation': obj_params[9:12],         # euler angles
                        # The prediction has no inlier_ratio; fill with 0 to keep keys consistent
                        'inlier_ratio': 0.0,
                        'velocity': obj_params[12:15],        # vx, vy, vz
                        'mass': phys_params[0],
                        'friction': phys_params[1],
                        'restitution': phys_params[2],
                    }
                    superquadrics.append(superquadric)

                # Convert the world parameters into a more readable format
                world_info = {
                    'camera_position': world_params[0:3],    # x, y, z
                    'camera_quaternion': world_params[3:7],  # w, x, y, z
                    'scene_scale': float(world_params[7]),   # scale
                    # The prediction has no scene_center; fill with zeros to keep the
                    # fields consistent for downstream readers
                    'scene_center': np.zeros(3, dtype=np.float32),
                }

                frame_data = {
                    'frame_idx': f_idx,
                    'superquadrics': superquadrics,
                    'world_info': world_info,
                }
                frame_predictions.append(frame_data)

            np.savez(sample_dir / "predictions.npz",
                     text=texts[i],
                     frames=frame_predictions,
                     num_frames=num_frames,
                     # Save physics for the whole sequence; write None if nothing was collected
                     physics=physics_per_frame if physics_per_frame else None,
                     sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
                     description="Predicted superquadric parameters for each frame")
            # 3. Save the ground-truth target data (if available) - improved format
            if targets is not None:
                # The data in `targets` covers the full batch; index [i] to get the current sample
                target_objects = targets['objects'][i].cpu().numpy() if hasattr(targets['objects'], 'cpu') else targets['objects'][i]  # [num_frames, max_objects, 16]
                target_world = targets['world'][i].cpu().numpy() if hasattr(targets['world'], 'cpu') else targets['world'][i]  # [num_frames, 11]
                if 'physics' in targets and targets['physics'] is not None:
                    target_physics = targets['physics'][i].cpu().numpy() if hasattr(targets['physics'], 'cpu') else targets['physics'][i]
                else:
                    target_physics = None

                # Build the top-level physics, matching the original Full_Sample_Data_for_Learning_Target
                target_physics_top = None
                if target_physics is not None:
                    target_physics_top = [
                        {
                            'mass': float(p[0]),
                            'friction': float(p[1]),
                            'restitution': float(p[2]),
                        }
                        for p in target_physics
                    ]

                # Convert the target data into a more readable format
                target_frames = []
                for f_idx in range(target_objects.shape[0]):
                    frame_objects = target_objects[f_idx]  # [max_objects, 16]
                    frame_world = target_world[f_idx]      # [11]

                    # Convert the object parameters
                    superquadrics = []
                    for obj_idx in range(frame_objects.shape[0]):
                        obj_params = frame_objects[obj_idx]
                        phys_params = target_physics[obj_idx] if target_physics is not None else np.zeros(3)
                        superquadric = {
                            'exists': bool(obj_params[0] > 0.5),    # exists flag
                            'shape': obj_params[1:3],               # epsilon1, epsilon2
                            'scale': obj_params[3:6],               # a, b, c
                            'translation': obj_params[6:9],         # x, y, z
                            'rotation': obj_params[9:12],           # euler angles
                            'inlier_ratio': float(obj_params[12]),  # GT specific: inlier ratio
                            'velocity': obj_params[13:16],          # vx, vy, vz
                            'mass': phys_params[0],
                            'friction': phys_params[1],
                            'restitution': phys_params[2],
                        }
                        superquadrics.append(superquadric)

                    # Convert the world parameters
                    world_info = {
                        'camera_position': frame_world[0:3],    # x, y, z
                        'camera_quaternion': frame_world[3:7],  # w, x, y, z
                        'scene_scale': float(frame_world[7]),   # scale
                        'scene_center': frame_world[8:11],      # center x, y, z
                    }

                    frame_data = {
                        'frame_idx': f_idx,
                        'superquadrics': superquadrics,
                        'world_info': world_info,
                    }
                    target_frames.append(frame_data)

                # Save targets.npz in the improved format
                np.savez(sample_dir / "targets.npz",
                         text=texts[i],
                         frames=target_frames,
                         num_frames=num_frames,
                         physics=target_physics_top if target_physics_top is not None else None,
                         sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
                         description="Ground truth superquadric parameters for each frame")

                # For compatibility, also keep the original format (used for error computation)
                target_data_legacy = {
                    'objects': target_objects,
                    'world': target_world,
                    'physics': target_physics,
                }

                # Compute and save error statistics (using the legacy format)
                save_error_statistics(frame_predictions, target_data_legacy, sample_dir)

            # 4. Save camera parameters (if available)
            if metadata and 'camera_data' in metadata:
                camera_data = metadata['camera_data'][i]
                np.savez(sample_dir / "camera_params.npz", **camera_data)
            # 5. Save the original data (new feature, no longer depends on camera_data)
            if batch_data is not None and data_root is not None:
                save_original_data(sample_dir, i, batch_data, metadata, data_root, data_split)

            print("✅")  # Completion marker

        except Exception as e:
            print(f"❌ Error: {e}")
            import traceback
            traceback.print_exc()
            continue  # Continue saving the remaining samples

    # Note: the save_visualization_script call has been removed, since a standalone
    # visualization script is not needed.

    # Save overall batch statistics
    save_batch_statistics(predictions, targets, save_path)

    print(f"\n{'='*60}")
    print("✅ Saving finished!")
    print(f"📁 Save path: {save_path}")
    print(f"{'='*60}\n")

    return save_path


def save_error_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
    """Compute and save prediction error statistics.

    Args:
        predictions: List of frames in the new format (containing superquadrics and world_info)
        targets: Target data in the legacy format (containing objects, world arrays)
    """
    stats = {}

    # Convert the new-format predictions back into arrays for error computation
    object_errors = []
    world_errors = []

    for frame in predictions:
        frame_idx = frame['frame_idx']

        # Rebuild the object array from the new format
        superquadrics = frame['superquadrics']
        pred_objects = []
        for sq in superquadrics:
            obj_params = np.zeros(15, dtype=np.float32)
            obj_params[0] = 1.0 if sq['exists'] else 0.0
            obj_params[1:3] = sq['shape']
            obj_params[3:6] = sq['scale']
            obj_params[6:9] = sq['translation']
            obj_params[9:12] = sq['rotation']
            obj_params[12:15] = sq['velocity']
            pred_objects.append(obj_params)
        pred_obj = np.array(pred_objects)

        # Get the target object data
        target_obj_full = targets['objects'][frame_idx]
        target_obj = target_obj_full[:, :pred_obj.shape[1]]  # Align with the model's predicted dimensions

        # Only consider objects that exist
        exists_mask = target_obj[:, 0] > 0.5
        if exists_mask.any():
            error = np.mean(np.abs(pred_obj[exists_mask] - target_obj[exists_mask]))
            object_errors.append(error)

        # Rebuild the world parameter array from the new format
        world_info = frame['world_info']
        pred_world = np.concatenate([
            world_info['camera_position'],
            world_info['camera_quaternion'],
            [world_info['scene_scale']]
        ])

        # Get the target world data
        target_world = targets['world'][frame_idx][:8]
        error = np.mean(np.abs(pred_world - target_world))
        world_errors.append(error)

    stats['object_mae'] = float(np.mean(object_errors)) if object_errors else 0.0
    stats['world_mae'] = float(np.mean(world_errors))

    # Save the statistics
    with open(save_dir / "error_statistics.json", 'w') as f:
        json.dump(stats, f, indent=2)


def save_batch_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
    """Save statistics for the whole batch."""
    batch_size = predictions[0]['objects'].shape[0]

    stats = {
        'batch_size': batch_size,
        'num_frames': len(predictions),
        'timestamp': datetime.now().isoformat(),
    }

    # Count the number of objects actually present in each frame
    if targets is not None:
        objects_per_frame = []
        for f_idx in range(len(predictions)):
            frame_objects = []
            for b_idx in range(batch_size):
                exists = targets['objects'][b_idx, f_idx, :, 0] > 0.5
                frame_objects.append(int(exists.sum()))
            objects_per_frame.append({
                'frame': f_idx,
                'mean_objects': float(np.mean(frame_objects)),
                'max_objects': int(max(frame_objects)),
                'min_objects': int(min(frame_objects)),
            })
        stats['objects_per_frame'] = objects_per_frame

    with open(save_dir / "batch_statistics.json", 'w') as f:
        json.dump(stats, f, indent=2)


def save_visualization_script(save_dir: Path):
    """Save a Python script for visualizing the superquadrics."""
    script = '''#!/usr/bin/env python3
"""
Visualize the generated superquadric parameters
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def superquadric_surface(a, b, c, e1, e2, n=50):
    """Generate surface points of a superquadric."""
    eta = np.linspace(-np.pi/2, np.pi/2, n)
    omega = np.linspace(-np.pi, np.pi, n)
    eta, omega = np.meshgrid(eta, omega)

    x = a * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.cos(omega)) * np.abs(np.cos(omega))**e2
    y = b * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.sin(omega)) * np.abs(np.sin(omega))**e2
    z = c * np.sign(np.sin(eta)) * np.abs(np.sin(eta))**e1
    return x, y, z


# Load the predicted data
data = np.load('predictions.npz', allow_pickle=True)
frames = data['frames']

# Visualize the first frame
frame = frames[0]
superquadrics = frame['superquadrics']  # list of per-object parameter dicts

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot every existing object
for obj_idx, sq in enumerate(superquadrics):
    if sq['exists']:  # the object exists
        # Extract the parameters
        shape_params = sq['shape']       # e1, e2
        scale = sq['scale']              # a, b, c
        translation = sq['translation']
        rotation = sq['rotation']        # simplified: rotation is not applied here

        # Generate the surface
        x, y, z = superquadric_surface(scale[0], scale[1], scale[2],
                                       shape_params[0], shape_params[1])

        # Apply the translation
        x += translation[0]
        y += translation[1]
        z += translation[2]

        # Plot
        ax.plot_surface(x, y, z, alpha=0.7, label=f'Object {obj_idx}')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Generated Superquadric Objects')
plt.savefig('visualization.png')
plt.show()
'''

    with open(save_dir / "visualize.py", 'w') as f:
        f.write(script)

    # Make the script executable
    os.chmod(save_dir / "visualize.py", 0o755)


def save_original_data(sample_dir: Path, sample_idx: int, batch_data: Dict,
                       metadata: Dict, data_root: str, data_split: str = "validation"):
    """
    Save the original data files, including RGB images, depth maps, segmentation maps, point clouds, etc.

    Args:
        sample_dir: Save directory of the current sample
        sample_idx: Sample index within the batch
        batch_data: Batch data
        metadata: Metadata
        data_root: Root directory of the MOVi dataset
        data_split: Dataset split ('train' or 'validation')
    """
    try:
        # Get the original sequence name
        sequence_name = None
        if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None:
            sequence_name = metadata['sequence_names'][sample_idx]

        if not sequence_name:
            print(f"Warning: No sequence name found for sample {sample_idx}")
            # Write a note explaining why there is no original data
            no_original_data_file = sample_dir / "no_original_data.txt"
            with open(no_original_data_file, 'w') as f:
                f.write("Original data was not saved because the sequence name could not be determined.\n")
                f.write("This may happen when the data loader does not return a sequence_names field.\n")
                f.write(f"Sample index: {sample_idx}\n")
                f.write(f"Generated at step: {datetime.now().isoformat()}\n")
            return

        # Build the path to the original data
        data_root_path = Path(data_root)
        # Use the given data_split directly to locate the original data
        original_sample_dir = data_root_path / data_split / sequence_name

        if not original_sample_dir.exists():
            print(f"Warning: Could not find original data for {sequence_name} in {data_split} split")
            # Write a note
            error_file = sample_dir / "original_data_not_found.txt"
            with open(error_file, 'w') as f:
                f.write("Original data not found\n")
                f.write(f"Search path: {original_sample_dir}\n")
                f.write(f"Dataset split: {data_split}\n")
                f.write(f"Sequence name: {sequence_name}\n")
                f.write(f"Time: {datetime.now().isoformat()}\n")
            return

        print(f"Copying original data from {original_sample_dir}")

        # Create the original-data subdirectory
        original_data_dir = sample_dir / "original_data"
        original_data_dir.mkdir(exist_ok=True)

        # 1. Copy the RGB images (all frames)
        rgb_dir = original_data_dir / "rgb"
        rgb_dir.mkdir(exist_ok=True)
        original_rgb_dir = original_sample_dir / "rgb"
        if original_rgb_dir.exists():
            for rgb_file in sorted(original_rgb_dir.glob("frame_*.png")):
                shutil.copy2(rgb_file, rgb_dir / rgb_file.name)

        # 2. Copy the depth maps (prefer the merged npz, otherwise the individual npy files)
        depth_dir = original_data_dir / "depth"
        depth_dir.mkdir(exist_ok=True)
        original_depth_dir = original_sample_dir / "depth"
        if original_depth_dir.exists():
            # Check whether a merged npz file exists
            merged_depth = original_depth_dir / "depth_merge.npz"
            if merged_depth.exists():
                shutil.copy2(merged_depth, depth_dir / "depth_merge.npz")
            else:
                # Copy the individual npy files
                for depth_file in sorted(original_depth_dir.glob("frame_*.npy")):
                    shutil.copy2(depth_file, depth_dir / depth_file.name)
        # 3. Copy the segmentation maps (prefer the merged npz, otherwise the individual npy files)
        seg_dir = original_data_dir / "segmentation"
        seg_dir.mkdir(exist_ok=True)
        original_seg_dir = original_sample_dir / "segmentation"
        if original_seg_dir.exists():
            # Check whether a merged npz file exists
            merged_seg = original_seg_dir / "segmentation_merge.npz"
            if merged_seg.exists():
                shutil.copy2(merged_seg, seg_dir / "segmentation_merge.npz")
            else:
                # Copy the individual npy files
                for seg_file in sorted(original_seg_dir.glob("frame_*.npy")):
                    shutil.copy2(seg_file, seg_dir / seg_file.name)

        # 4. Copy the normal maps (prefer the merged npz, otherwise the individual npy files)
        normal_dir = original_data_dir / "normal"
        normal_dir.mkdir(exist_ok=True)
        original_normal_dir = original_sample_dir / "normal"
        if original_normal_dir.exists():
            # Check whether a merged npz file exists
            merged_normal = original_normal_dir / "normal_merge.npz"
            if merged_normal.exists():
                shutil.copy2(merged_normal, normal_dir / "normal_merge.npz")
            else:
                # Copy the individual npy files
                for normal_file in sorted(original_normal_dir.glob("frame_*.npy")):
                    shutil.copy2(normal_file, normal_dir / normal_file.name)

        # 5. Copy the camera trajectory
        camera_traj_file = original_sample_dir / "camera_trajectory.npz"
        if camera_traj_file.exists():
            shutil.copy2(camera_traj_file, original_data_dir / "camera_trajectory.npz")

        # 6. Copy the metadata JSON
        metadata_file = original_sample_dir / "metadata.json"
        if metadata_file.exists():
            shutil.copy2(metadata_file, original_data_dir / "metadata.json")

        # 7. Copy the full training-target cache file (Full_Sample_Data_for_Learning_Target.npz)
        full_cache_file = original_sample_dir / "Full_Sample_Data_for_Learning_Target.npz"
        if full_cache_file.exists():
            shutil.copy2(full_cache_file, original_data_dir / "Full_Sample_Data_for_Learning_Target.npz")

        # 8. Copy other possible merged files (object_coordinates, point_clouds, etc.)
        for folder_name in ['object_coordinates', 'point_clouds']:
            folder_dir = original_data_dir / folder_name
            original_folder_dir = original_sample_dir / folder_name
            if original_folder_dir.exists():
                folder_dir.mkdir(exist_ok=True)
                # Check whether a merged npz file exists
                merged_file = original_folder_dir / f"{folder_name}_merge.npz"
                if merged_file.exists():
                    shutil.copy2(merged_file, folder_dir / f"{folder_name}_merge.npz")
                else:
                    # Copy the individual npy files
                    for npy_file in sorted(original_folder_dir.glob("frame_*.npy")):
                        shutil.copy2(npy_file, folder_dir / npy_file.name)

        # 9. If preprocessed point clouds are available (in batch_data), save them as well
        if 'point_clouds' in batch_data:
            pc_data = batch_data['point_clouds'][sample_idx]
            np.savez_compressed(original_data_dir / "point_clouds.npz", **pc_data)

        # 10. Save the scene normalization parameters
        if 'scene_normalization' in batch_data:
            norm_params = batch_data['scene_normalization'][sample_idx]
            with open(original_data_dir / "scene_normalization.json", 'w') as f:
                json.dump({
                    'scene_center': norm_params['center'].tolist() if hasattr(norm_params['center'], 'tolist') else norm_params['center'],
                    'scene_scale': float(norm_params['scale']) if 'scale' in norm_params else 1.0,
                    'scene_extent': float(norm_params['extent']) if 'extent' in norm_params else 1.0
                }, f, indent=2)

        # 11. Create a file manifest
        with open(original_data_dir / "file_manifest.txt", 'w') as f:
            f.write(f"Original sequence: {sequence_name}\n")
            f.write(f"Data split: {data_split}\n")
            f.write(f"Original path: {original_sample_dir}\n")
            f.write(f"Copied at: {datetime.now().isoformat()}\n\n")
            f.write("Files included:\n")
            for item in sorted(original_data_dir.rglob("*")):
                if item.is_file() and item.name != "file_manifest.txt":
                    f.write(f"- {item.relative_to(original_data_dir)}\n")

    except Exception as e:
        print(f"Error saving original data for sample {sample_idx}: {e}")
        import traceback
        traceback.print_exc()
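

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It shows the tensor shapes save_generation_results expects, using random
# tensors. The batch size, frame count, example texts, and the output
# directory name below are assumptions made purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dummy_predictions = [
        {
            "objects": torch.rand(2, 3, 15),  # [batch, max_objects, 15]
            "world": torch.rand(2, 8),        # [batch, 8]
            "physics": torch.rand(2, 3, 3),   # [batch, max_objects, 3]
        }
        for _ in range(4)                     # 4 frames
    ]
    save_generation_results(
        predictions=dummy_predictions,
        targets=None,                         # no ground truth in this sketch
        texts=["a red cube falls onto the floor", "two spheres collide"],
        step=0,
        output_dir="outputs_debug",
    )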