# nano_WaveGen/utils/save_generation_results.py
"""
保存生成结果的工具函数
用于WaveGen v33 - 使用超二次元函数的版本
"""
import os
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List, Optional
import json
from datetime import datetime
import shutil
def save_generation_results(
predictions: List[Dict],
targets: Dict[str, torch.Tensor],
texts: List[str],
step: int,
output_dir: str = "outputs",
save_config: Dict = None,
metadata: Dict = None,
    batch_data: Optional[Dict] = None,   # new: the full batch data
    data_root: Optional[str] = None,     # new: root directory of the original dataset
    data_split: str = "validation"       # new: dataset split (train/validation)
):
"""
保存生成结果用于可视化和分析(增强版)
Args:
predictions: 模型预测结果列表,每个元素包含 'objects', 'world', 'physics'
targets: 真实目标数据
texts: 输入文本描述
step: 当前训练步数
output_dir: 输出目录
save_config: 保存配置
metadata: 额外的元数据(如序列名称、相机参数等)
batch_data: 完整的批次数据,用于获取更多原始信息
data_root: 原始数据根目录,用于复制原始文件
data_split: 数据集split('train' 或 'validation'),用于确定原始数据的正确位置
"""
    # Create a timestamped directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = Path(output_dir) / f"{timestamp}_step{step}_text2wave"
save_path.mkdir(parents=True, exist_ok=True)
    # Persist the save configuration
if save_config is not None:
with open(save_path / "save_config.json", 'w') as f:
json.dump(save_config, f, indent=2)
batch_size = len(texts)
num_frames = len(predictions)
print(f"\n{'='*60}")
print(f"💾 保存生成结果: step {step}")
print(f"📁 保存路径: {save_path}")
print(f"📊 样本数: {batch_size}, 帧数: {num_frames}")
print(f"{'='*60}\n")
for i in range(batch_size):
try:
sample_dir = save_path / f"sample_{i}"
sample_dir.mkdir(exist_ok=True)
print(f"正在保存 sample_{i}... ", end='', flush=True)
            # 1. Save the text description and metadata
with open(sample_dir / "info.txt", 'w') as f:
f.write(f"Text: {texts[i]}\n")
f.write(f"Generated at step: {step}\n")
f.write(f"Number of frames: {num_frames}\n")
if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None:
f.write(f"Sequence: {metadata['sequence_names'][i]}\n")
f.write("\n--- Model Output Summary ---\n")
f.write(f"Max objects: {predictions[0]['objects'].shape[1]}\n")
f.write("Object parameters: 15 (exists + shape[2] + scale[3] + translation[3] + rotation[3] + velocity[3])\n")
f.write(f"World parameters: 8 (camera_pos[3] + camera_quat[4] + scene_scale[1])\n")
f.write(f"Physics parameters: 3 (mass + friction + restitution)\n")
            # 2. Save the generated superquadric parameters (improved format)
frame_predictions = []
            physics_per_frame = []  # per-frame physics predictions, handy for inspecting the full sequence
for f_idx, pred in enumerate(predictions):
                # Each element of predictions holds one frame for the full batch; pick out the current sample first.
objects_batch = pred['objects']
world_batch = pred['world']
physics_batch = pred.get('physics')
objects_params = objects_batch[i].cpu().numpy() if hasattr(objects_batch, 'cpu') else objects_batch[i] # [max_objects, 15]
world_params = world_batch[i].cpu().numpy() if hasattr(world_batch, 'cpu') else world_batch[i] # [8]
if physics_batch is not None:
physics_params = physics_batch[i].cpu().numpy() if hasattr(physics_batch, 'cpu') else physics_batch[i] # [max_objects, 3]
else:
physics_params = np.zeros((objects_params.shape[0], 3), dtype=np.float32)
                # Save this frame's physics (per-object mass/friction/restitution)
physics_per_frame.append([
{
'mass': float(phys_params[0]),
'friction': float(phys_params[1]),
'restitution': float(phys_params[2]),
}
for phys_params in physics_params
])
                # Convert object parameters into a more readable format
superquadrics = []
for obj_idx in range(objects_params.shape[0]):
obj_params = objects_params[obj_idx]
phys_params = physics_params[obj_idx]
superquadric = {
'exists': bool(obj_params[0] > 0.5), # exists flag
'shape': obj_params[1:3], # epsilon1, epsilon2
'scale': obj_params[3:6], # a, b, c
'translation': obj_params[6:9], # x, y, z
'rotation': obj_params[9:12], # euler angles
                        # Predictions carry no inlier_ratio; fill with 0 to keep the key set consistent
'inlier_ratio': 0.0,
'velocity': obj_params[12:15], # vx, vy, vz
'mass': phys_params[0],
'friction': phys_params[1],
'restitution': phys_params[2],
}
superquadrics.append(superquadric)
                # Convert world parameters into a more readable format
world_info = {
'camera_position': world_params[0:3], # x, y, z
'camera_quaternion': world_params[3:7], # w, x, y, z
'scene_scale': float(world_params[7]), # scale
                    # Predictions carry no scene_center; fill with zeros so downstream readers see a consistent field
'scene_center': np.zeros(3, dtype=np.float32),
}
frame_data = {
'frame_idx': f_idx,
'superquadrics': superquadrics,
'world_info': world_info,
}
frame_predictions.append(frame_data)
np.savez(sample_dir / "predictions.npz",
text=texts[i],
frames=frame_predictions,
num_frames=num_frames,
                     # save physics for the whole sequence; write None if nothing was predicted
physics=physics_per_frame if physics_per_frame else None,
sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
description="Predicted superquadric parameters for each frame")
            # 3. Save ground-truth target data (if available) -- improved format
if targets is not None:
                # Data in targets covers the full batch; index [i] to get the current sample.
target_objects = targets['objects'][i].cpu().numpy() if hasattr(targets['objects'], 'cpu') else targets['objects'][i] # [num_frames, max_objects, 16]
target_world = targets['world'][i].cpu().numpy() if hasattr(targets['world'], 'cpu') else targets['world'][i] # [num_frames, 11]
if 'physics' in targets and targets['physics'] is not None:
target_physics = targets['physics'][i].cpu().numpy() if hasattr(targets['physics'], 'cpu') else targets['physics'][i]
else:
target_physics = None
                # Build a top-level physics entry, matching the original Full_Sample_Data_for_Learning_Target layout
target_physics_top = None
if target_physics is not None:
target_physics_top = [
{
'mass': float(p[0]),
'friction': float(p[1]),
'restitution': float(p[2]),
}
for p in target_physics
]
                # Convert target data into a more readable format
target_frames = []
for f_idx in range(target_objects.shape[0]):
frame_objects = target_objects[f_idx] # [max_objects, 16]
frame_world = target_world[f_idx] # [11]
                    # Convert object parameters
superquadrics = []
for obj_idx in range(frame_objects.shape[0]):
obj_params = frame_objects[obj_idx]
phys_params = target_physics[obj_idx] if target_physics is not None else np.zeros(3)
superquadric = {
'exists': bool(obj_params[0] > 0.5), # exists flag
'shape': obj_params[1:3], # epsilon1, epsilon2
'scale': obj_params[3:6], # a, b, c
'translation': obj_params[6:9], # x, y, z
'rotation': obj_params[9:12], # euler angles
'inlier_ratio': float(obj_params[12]), # GT specific: inlier ratio
'velocity': obj_params[13:16], # vx, vy, vz
'mass': phys_params[0],
'friction': phys_params[1],
'restitution': phys_params[2],
}
superquadrics.append(superquadric)
                    # Convert world parameters
world_info = {
'camera_position': frame_world[0:3], # x, y, z
'camera_quaternion': frame_world[3:7], # w, x, y, z
'scene_scale': float(frame_world[7]), # scale
'scene_center': frame_world[8:11], # center x, y, z
}
frame_data = {
'frame_idx': f_idx,
'superquadrics': superquadrics,
'world_info': world_info,
}
target_frames.append(frame_data)
                # Save targets.npz in the improved format
np.savez(sample_dir / "targets.npz",
text=texts[i],
frames=target_frames,
num_frames=num_frames,
physics=target_physics_top if target_physics_top is not None else None,
sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
description="Ground truth superquadric parameters for each frame")
                # For compatibility, also keep the raw-array format (used for error computation)
target_data_legacy = {
'objects': target_objects,
'world': target_world,
'physics': target_physics,
}
                # Compute and save error statistics (using the legacy format)
save_error_statistics(frame_predictions, target_data_legacy, sample_dir)
            # 4. Save camera parameters (if available)
if metadata and 'camera_data' in metadata:
camera_data = metadata['camera_data'][i]
np.savez(sample_dir / "camera_params.npz",
**camera_data)
            # 5. Save the original data (new feature; no longer depends on camera_data)
if batch_data is not None and data_root is not None:
save_original_data(sample_dir, i, batch_data, metadata, data_root, data_split)
print("✅") # 完成标记
except Exception as e:
print(f"❌ 错误: {e}")
import traceback
traceback.print_exc()
            continue  # keep saving the remaining samples
    # Note: save_visualization_script is no longer called here, since a standalone visualization script is not needed.
    # Save batch-level statistics
save_batch_statistics(predictions, targets, save_path)
print(f"\n{'='*60}")
print(f"✅ 保存完成!")
print(f"📁 保存路径: {save_path}")
print(f"{'='*60}\n")
return save_path
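
# Resulting directory layout, assembled from the steps above:
#   <output_dir>/<timestamp>_step<step>_text2wave/
#       save_config.json             (if save_config was given)
#       batch_statistics.json
#       sample_<i>/
#           info.txt
#           predictions.npz
#           targets.npz              (if targets were given)
#           error_statistics.json    (if targets were given)
#           camera_params.npz        (if metadata contains camera_data)
#           original_data/           (if batch_data and data_root were given)
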
def save_error_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
"""计算并保存预测误差统计
Args:
predictions: 新格式的帧列表 (包含 superquadrics 和 world_info)
targets: 旧格式的目标数据 (包含 objects, world 数组)
"""
stats = {}
    # Convert the new-format predictions back into arrays for error computation
object_errors = []
world_errors = []
for frame in predictions:
frame_idx = frame['frame_idx']
        # Rebuild the object array from the new format
superquadrics = frame['superquadrics']
pred_objects = []
for sq in superquadrics:
obj_params = np.zeros(15, dtype=np.float32)
obj_params[0] = 1.0 if sq['exists'] else 0.0
obj_params[1:3] = sq['shape']
obj_params[3:6] = sq['scale']
obj_params[6:9] = sq['translation']
obj_params[9:12] = sq['rotation']
obj_params[12:15] = sq['velocity']
pred_objects.append(obj_params)
pred_obj = np.array(pred_objects)
        # Get the target object data. The GT layout has 16 dims with the
        # GT-only inlier_ratio at index 12, which the 15-dim predictions lack,
        # so drop that column to align the remaining dimensions.
        target_obj_full = targets['objects'][frame_idx]
        target_obj = np.concatenate(
            [target_obj_full[:, :12], target_obj_full[:, 13:16]], axis=1)
        # Only score objects that actually exist
        exists_mask = target_obj[:, 0] > 0.5
if exists_mask.any():
error = np.mean(np.abs(pred_obj[exists_mask] - target_obj[exists_mask]))
object_errors.append(error)
        # Rebuild the world-parameter array from the new format
world_info = frame['world_info']
pred_world = np.concatenate([
world_info['camera_position'],
world_info['camera_quaternion'],
[world_info['scene_scale']]
])
        # Get the target world data (the model predicts only the first 8 dims)
        target_world = targets['world'][frame_idx][:8]
error = np.mean(np.abs(pred_world - target_world))
world_errors.append(error)
    stats['object_mae'] = float(np.mean(object_errors)) if object_errors else 0.0
    stats['world_mae'] = float(np.mean(world_errors)) if world_errors else 0.0
    # Save the statistics
with open(save_dir / "error_statistics.json", 'w') as f:
json.dump(stats, f, indent=2)
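
# For reference, error_statistics.json ends up shaped like this
# (values are illustrative, not from a real run):
#   {"object_mae": 0.1234, "world_mae": 0.0567}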
def save_batch_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
"""保存整批数据的统计信息"""
batch_size = predictions[0]['objects'].shape[0]
stats = {
'batch_size': batch_size,
'num_frames': len(predictions),
'timestamp': datetime.now().isoformat(),
}
    # Count how many objects actually exist in each frame
if targets is not None:
objects_per_frame = []
for f_idx in range(len(predictions)):
frame_objects = []
for b_idx in range(batch_size):
exists = targets['objects'][b_idx, f_idx, :, 0] > 0.5
frame_objects.append(int(exists.sum()))
objects_per_frame.append({
'frame': f_idx,
'mean_objects': float(np.mean(frame_objects)),
'max_objects': int(max(frame_objects)),
'min_objects': int(min(frame_objects)),
})
stats['objects_per_frame'] = objects_per_frame
with open(save_dir / "batch_statistics.json", 'w') as f:
json.dump(stats, f, indent=2)
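
# batch_statistics.json sketch (illustrative values):
#   {"batch_size": 4, "num_frames": 24, "timestamp": "2024-01-01T00:00:00",
#    "objects_per_frame": [{"frame": 0, "mean_objects": 3.5,
#                           "max_objects": 5, "min_objects": 2}, ...]}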
def save_visualization_script(save_dir: Path):
"""保存用于可视化超二次元函数的Python脚本"""
script = '''#!/usr/bin/env python3
"""
可视化生成的超二次元函数参数
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def superquadric_surface(a, b, c, e1, e2, n=50):
"""生成超二次元函数表面点"""
eta = np.linspace(-np.pi/2, np.pi/2, n)
omega = np.linspace(-np.pi, np.pi, n)
eta, omega = np.meshgrid(eta, omega)
x = a * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.cos(omega)) * np.abs(np.cos(omega))**e2
y = b * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.sin(omega)) * np.abs(np.sin(omega))**e2
z = c * np.sign(np.sin(eta)) * np.abs(np.sin(eta))**e1
return x, y, z
# Load the prediction data
data = np.load('predictions.npz', allow_pickle=True)
frames = data['frames']
# Visualize the first frame. Note: predictions.npz stores each frame as a dict
# with a 'superquadrics' list of per-object dicts, not a raw parameter array.
frame = frames[0]
superquadrics = frame['superquadrics']  # list of per-object parameter dicts
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
# Draw every object that exists
for obj_idx, sq in enumerate(superquadrics):
    if sq['exists']:
        # Extract the parameters
        shape_params = sq['shape']       # e1, e2
        scale = sq['scale']              # a, b, c
        translation = sq['translation']
        rotation = sq['rotation']        # simplified: rotation is not applied
        # Generate the surface
x, y, z = superquadric_surface(scale[0], scale[1], scale[2],
shape_params[0], shape_params[1])
        # Apply the translation
x += translation[0]
y += translation[1]
z += translation[2]
        # Draw
ax.plot_surface(x, y, z, alpha=0.7, label=f'Object {obj_idx}')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Generated Superquadric Objects')
plt.savefig('visualization.png')
plt.show()
'''
with open(save_dir / "visualize.py", 'w') as f:
f.write(script)
    # Make the script executable
os.chmod(save_dir / "visualize.py", 0o755)
def save_original_data(sample_dir: Path, sample_idx: int, batch_data: Dict, metadata: Dict, data_root: str, data_split: str = "validation"):
"""
保存原始数据文件,包括RGB图像、深度图、分割图、点云等
Args:
sample_dir: 当前样本保存目录
sample_idx: 批次中的样本索引
batch_data: 批次数据
metadata: 元数据
data_root: MOVi数据集根目录
data_split: 数据集split('train' 或 'validation')
"""
try:
        # Get the original sequence name
        sequence_name = None
        if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None:
            sequence_name = metadata['sequence_names'][sample_idx]
        if not sequence_name:
            print(f"Warning: No sequence name found for sample {sample_idx}")
            # Write an explanatory file describing why the original data is missing
            no_original_data_file = sample_dir / "no_original_data.txt"
            with open(no_original_data_file, 'w') as f:
                f.write("Original data was not saved because the sequence name could not be determined.\n")
                f.write("This likely means the data loader did not return a 'sequence_names' field.\n")
                f.write(f"Sample index: {sample_idx}\n")
                f.write(f"Generated at: {datetime.now().isoformat()}\n")
            return
        # Build the path to the original data
        data_root_path = Path(data_root)
        # Use the given data_split directly to locate the original data
        original_sample_dir = data_root_path / data_split / sequence_name
        if not original_sample_dir.exists():
            print(f"Warning: Could not find original data for {sequence_name} in {data_split} split")
            # Write an explanatory file
            error_file = sample_dir / "original_data_not_found.txt"
            with open(error_file, 'w') as f:
                f.write("Original data not found\n")
                f.write(f"Searched path: {original_sample_dir}\n")
                f.write(f"Split: {data_split}\n")
                f.write(f"Sequence: {sequence_name}\n")
                f.write(f"Time: {datetime.now().isoformat()}\n")
            return
print(f"Copying original data from {original_sample_dir}")
        # Create a subdirectory for the original data
original_data_dir = sample_dir / "original_data"
original_data_dir.mkdir(exist_ok=True)
        # 1. Copy RGB images (all frames)
rgb_dir = original_data_dir / "rgb"
rgb_dir.mkdir(exist_ok=True)
original_rgb_dir = original_sample_dir / "rgb"
if original_rgb_dir.exists():
for rgb_file in sorted(original_rgb_dir.glob("frame_*.png")):
shutil.copy2(rgb_file, rgb_dir / rgb_file.name)
        # 2. Copy depth maps (prefer the merged npz, else the individual npy files)
        depth_dir = original_data_dir / "depth"
        depth_dir.mkdir(exist_ok=True)
        original_depth_dir = original_sample_dir / "depth"
        if original_depth_dir.exists():
            # Check for a merged npz file
            merged_depth = original_depth_dir / "depth_merge.npz"
            if merged_depth.exists():
                shutil.copy2(merged_depth, depth_dir / "depth_merge.npz")
            else:
                # Copy the individual npy files
                for depth_file in sorted(original_depth_dir.glob("frame_*.npy")):
                    shutil.copy2(depth_file, depth_dir / depth_file.name)
        # 3. Copy segmentation maps (prefer the merged npz, else the individual npy files)
        seg_dir = original_data_dir / "segmentation"
        seg_dir.mkdir(exist_ok=True)
        original_seg_dir = original_sample_dir / "segmentation"
        if original_seg_dir.exists():
            # Check for a merged npz file
            merged_seg = original_seg_dir / "segmentation_merge.npz"
            if merged_seg.exists():
                shutil.copy2(merged_seg, seg_dir / "segmentation_merge.npz")
            else:
                # Copy the individual npy files
                for seg_file in sorted(original_seg_dir.glob("frame_*.npy")):
                    shutil.copy2(seg_file, seg_dir / seg_file.name)
        # 4. Copy normal maps (prefer the merged npz, else the individual npy files)
        normal_dir = original_data_dir / "normal"
        normal_dir.mkdir(exist_ok=True)
        original_normal_dir = original_sample_dir / "normal"
        if original_normal_dir.exists():
            # Check for a merged npz file
            merged_normal = original_normal_dir / "normal_merge.npz"
            if merged_normal.exists():
                shutil.copy2(merged_normal, normal_dir / "normal_merge.npz")
            else:
                # Copy the individual npy files
                for normal_file in sorted(original_normal_dir.glob("frame_*.npy")):
                    shutil.copy2(normal_file, normal_dir / normal_file.name)
        # 5. Copy the camera trajectory
camera_traj_file = original_sample_dir / "camera_trajectory.npz"
if camera_traj_file.exists():
shutil.copy2(camera_traj_file, original_data_dir / "camera_trajectory.npz")
        # 6. Copy the metadata JSON
metadata_file = original_sample_dir / "metadata.json"
if metadata_file.exists():
shutil.copy2(metadata_file, original_data_dir / "metadata.json")
        # 7. Copy the cached full learning-target file (Full_Sample_Data_for_Learning_Target.npz)
full_cache_file = original_sample_dir / "Full_Sample_Data_for_Learning_Target.npz"
if full_cache_file.exists():
shutil.copy2(full_cache_file, original_data_dir / "Full_Sample_Data_for_Learning_Target.npz")
        # 8. Copy other possible merged folders (object_coordinates, point_clouds, etc.)
        for folder_name in ['object_coordinates', 'point_clouds']:
            folder_dir = original_data_dir / folder_name
            original_folder_dir = original_sample_dir / folder_name
            if original_folder_dir.exists():
                folder_dir.mkdir(exist_ok=True)
                # Check for a merged npz file
                merged_file = original_folder_dir / f"{folder_name}_merge.npz"
                if merged_file.exists():
                    shutil.copy2(merged_file, folder_dir / f"{folder_name}_merge.npz")
                else:
                    # Copy the individual npy files
                    for npy_file in sorted(original_folder_dir.glob("frame_*.npy")):
                        shutil.copy2(npy_file, folder_dir / npy_file.name)
        # 9. If preprocessed point clouds are present in batch_data, save them too
if 'point_clouds' in batch_data:
pc_data = batch_data['point_clouds'][sample_idx]
np.savez_compressed(original_data_dir / "point_clouds.npz", **pc_data)
        # 10. Save scene-normalization parameters
if 'scene_normalization' in batch_data:
norm_params = batch_data['scene_normalization'][sample_idx]
with open(original_data_dir / "scene_normalization.json", 'w') as f:
json.dump({
'scene_center': norm_params['center'].tolist() if hasattr(norm_params['center'], 'tolist') else norm_params['center'],
'scene_scale': float(norm_params['scale']) if 'scale' in norm_params else 1.0,
'scene_extent': float(norm_params['extent']) if 'extent' in norm_params else 1.0
}, f, indent=2)
        # 11. Write a file manifest
with open(original_data_dir / "file_manifest.txt", 'w') as f:
f.write(f"Original sequence: {sequence_name}\n")
f.write(f"Data split: {data_split}\n")
f.write(f"Original path: {original_sample_dir}\n")
f.write(f"Copied at: {datetime.now().isoformat()}\n\n")
f.write("Files included:\n")
for item in sorted(original_data_dir.rglob("*")):
if item.is_file() and item.name != "file_manifest.txt":
f.write(f"- {item.relative_to(original_data_dir)}\n")
except Exception as e:
print(f"Error saving original data for sample {sample_idx}: {e}")
import traceback
traceback.print_exc()
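

if __name__ == "__main__":
    # Minimal smoke test -- a sketch with made-up shapes (batch of 1, 2 frames,
    # 3 object slots); not part of the original training pipeline.
    B, F, MAX_OBJ = 1, 2, 3
    dummy_predictions = [
        {
            'objects': torch.rand(B, MAX_OBJ, 15),
            'world': torch.rand(B, 8),
            'physics': torch.rand(B, MAX_OBJ, 3),
        }
        for _ in range(F)
    ]
    dummy_targets = {
        'objects': torch.rand(B, F, MAX_OBJ, 16),
        'world': torch.rand(B, F, 11),
        'physics': torch.rand(B, MAX_OBJ, 3),
    }
    save_generation_results(
        dummy_predictions, dummy_targets,
        texts=["a cube sliding across the floor"],
        step=0, output_dir="outputs_smoke_test")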