|
|
""" |
|
|
保存生成结果的工具函数 |
|
|
用于WaveGen v33 - 使用超二次元函数的版本 |
|
|
""" |
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional |
|
|
import json |
|
|
from datetime import datetime |
|
|
import shutil |
|
|
|
|
|
|
|
|
def save_generation_results(
    predictions: List[Dict],
    targets: Dict[str, torch.Tensor],
    texts: List[str],
    step: int,
    output_dir: str = "outputs",
    save_config: Dict = None,
    metadata: Dict = None,
    batch_data: Dict = None,
    data_root: str = None,
    data_split: str = "validation"
):
    """
    Save generation results for visualization and analysis (enhanced version).

    Creates a timestamped directory under ``output_dir`` and, for each sample
    in the batch, writes a human-readable ``info.txt``, ``predictions.npz``
    (and ``targets.npz`` plus error statistics when ``targets`` is given),
    optional camera parameters, and optionally a copy of the raw dataset
    files. A failure while saving one sample is logged and skipped so the
    remaining samples are still saved.

    Args:
        predictions: List of per-frame model outputs; each element holds
            batched 'objects', 'world' and (optionally) 'physics' data.
        targets: Ground-truth data keyed by 'objects' / 'world' / 'physics'.
        texts: Input text descriptions, one per batch sample.
        step: Current training step (embedded in the output directory name).
        output_dir: Root output directory.
        save_config: Optional save configuration, dumped as JSON.
        metadata: Extra metadata (e.g. sequence names, camera parameters).
        batch_data: Full batch data, used to recover more raw information.
        data_root: Raw data root directory, used to copy original files.
        data_split: Dataset split ('train' or 'validation') used to locate
            the raw data.

    Returns:
        Path of the created save directory.
    """
    # One directory per call: wall-clock timestamp plus training step.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = Path(output_dir) / f"{timestamp}_step{step}_text2wave"
    save_path.mkdir(parents=True, exist_ok=True)

    if save_config is not None:
        with open(save_path / "save_config.json", 'w') as f:
            json.dump(save_config, f, indent=2)

    batch_size = len(texts)
    num_frames = len(predictions)

    print(f"\n{'='*60}")
    print(f"💾 保存生成结果: step {step}")
    print(f"📁 保存路径: {save_path}")
    print(f"📊 样本数: {batch_size}, 帧数: {num_frames}")
    print(f"{'='*60}\n")

    for i in range(batch_size):
        # Per-sample isolation: an error for one sample must not abort the batch.
        try:
            sample_dir = save_path / f"sample_{i}"
            sample_dir.mkdir(exist_ok=True)
            print(f"正在保存 sample_{i}... ", end='', flush=True)

            # Human-readable summary of this sample's inputs and output layout.
            with open(sample_dir / "info.txt", 'w') as f:
                f.write(f"Text: {texts[i]}\n")
                f.write(f"Generated at step: {step}\n")
                f.write(f"Number of frames: {num_frames}\n")

                if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None:
                    f.write(f"Sequence: {metadata['sequence_names'][i]}\n")

                f.write("\n--- Model Output Summary ---\n")
                f.write(f"Max objects: {predictions[0]['objects'].shape[1]}\n")
                f.write("Object parameters: 15 (exists + shape[2] + scale[3] + translation[3] + rotation[3] + velocity[3])\n")
                f.write(f"World parameters: 8 (camera_pos[3] + camera_quat[4] + scene_scale[1])\n")
                f.write(f"Physics parameters: 3 (mass + friction + restitution)\n")

            # Convert batched per-frame predictions into structured per-sample dicts.
            frame_predictions = []
            physics_per_frame = []
            for f_idx, pred in enumerate(predictions):
                objects_batch = pred['objects']
                world_batch = pred['world']
                physics_batch = pred.get('physics')

                # Torch tensors are moved to CPU numpy; plain arrays pass through.
                objects_params = objects_batch[i].cpu().numpy() if hasattr(objects_batch, 'cpu') else objects_batch[i]
                world_params = world_batch[i].cpu().numpy() if hasattr(world_batch, 'cpu') else world_batch[i]
                if physics_batch is not None:
                    physics_params = physics_batch[i].cpu().numpy() if hasattr(physics_batch, 'cpu') else physics_batch[i]
                else:
                    # No physics head: fall back to zero mass/friction/restitution per object.
                    physics_params = np.zeros((objects_params.shape[0], 3), dtype=np.float32)

                physics_per_frame.append([
                    {
                        'mass': float(phys_params[0]),
                        'friction': float(phys_params[1]),
                        'restitution': float(phys_params[2]),
                    }
                    for phys_params in physics_params
                ])

                # Prediction layout is 15 params per object (no inlier_ratio slot),
                # so inlier_ratio is fixed to 0.0 here.
                superquadrics = []
                for obj_idx in range(objects_params.shape[0]):
                    obj_params = objects_params[obj_idx]
                    phys_params = physics_params[obj_idx]

                    superquadric = {
                        'exists': bool(obj_params[0] > 0.5),
                        'shape': obj_params[1:3],
                        'scale': obj_params[3:6],
                        'translation': obj_params[6:9],
                        'rotation': obj_params[9:12],

                        'inlier_ratio': 0.0,
                        'velocity': obj_params[12:15],
                        'mass': phys_params[0],
                        'friction': phys_params[1],
                        'restitution': phys_params[2],
                    }
                    superquadrics.append(superquadric)

                # Predicted world has 8 params (no scene_center); store zeros so
                # the structure matches the target layout.
                world_info = {
                    'camera_position': world_params[0:3],
                    'camera_quaternion': world_params[3:7],
                    'scene_scale': float(world_params[7]),

                    'scene_center': np.zeros(3, dtype=np.float32),
                }

                frame_data = {
                    'frame_idx': f_idx,
                    'superquadrics': superquadrics,
                    'world_info': world_info,
                }
                frame_predictions.append(frame_data)

            np.savez(sample_dir / "predictions.npz",
                     text=texts[i],
                     frames=frame_predictions,
                     num_frames=num_frames,

                     physics=physics_per_frame if physics_per_frame else None,
                     sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
                     description="Predicted superquadric parameters for each frame")

            # Ground-truth side: same structure, but using the target layout
            # (16 object params with inlier_ratio at index 12, 11 world params).
            if targets is not None:

                target_objects = targets['objects'][i].cpu().numpy() if hasattr(targets['objects'], 'cpu') else targets['objects'][i]
                target_world = targets['world'][i].cpu().numpy() if hasattr(targets['world'], 'cpu') else targets['world'][i]

                if 'physics' in targets and targets['physics'] is not None:
                    target_physics = targets['physics'][i].cpu().numpy() if hasattr(targets['physics'], 'cpu') else targets['physics'][i]
                else:
                    target_physics = None

                # Top-level per-object physics summary (frame-invariant in targets).
                target_physics_top = None
                if target_physics is not None:
                    target_physics_top = [
                        {
                            'mass': float(p[0]),
                            'friction': float(p[1]),
                            'restitution': float(p[2]),
                        }
                        for p in target_physics
                    ]

                target_frames = []
                for f_idx in range(target_objects.shape[0]):
                    frame_objects = target_objects[f_idx]
                    frame_world = target_world[f_idx]

                    superquadrics = []
                    for obj_idx in range(frame_objects.shape[0]):
                        obj_params = frame_objects[obj_idx]
                        # Target physics is indexed by object only, not by frame.
                        phys_params = target_physics[obj_idx] if target_physics is not None else np.zeros(3)

                        superquadric = {
                            'exists': bool(obj_params[0] > 0.5),
                            'shape': obj_params[1:3],
                            'scale': obj_params[3:6],
                            'translation': obj_params[6:9],
                            'rotation': obj_params[9:12],
                            'inlier_ratio': float(obj_params[12]),
                            'velocity': obj_params[13:16],
                            'mass': phys_params[0],
                            'friction': phys_params[1],
                            'restitution': phys_params[2],
                        }
                        superquadrics.append(superquadric)

                    # Target world carries a real scene_center at indices 8:11.
                    world_info = {
                        'camera_position': frame_world[0:3],
                        'camera_quaternion': frame_world[3:7],
                        'scene_scale': float(frame_world[7]),
                        'scene_center': frame_world[8:11],
                    }

                    frame_data = {
                        'frame_idx': f_idx,
                        'superquadrics': superquadrics,
                        'world_info': world_info,
                    }
                    target_frames.append(frame_data)

                np.savez(sample_dir / "targets.npz",
                         text=texts[i],
                         frames=target_frames,
                         num_frames=num_frames,
                         physics=target_physics_top if target_physics_top is not None else None,
                         sequence_name=metadata['sequence_names'][i] if (metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None) else "unknown",
                         description="Ground truth superquadric parameters for each frame")

                # Legacy flat-array view of the targets, consumed by
                # save_error_statistics below.
                target_data_legacy = {
                    'objects': target_objects,
                    'world': target_world,
                    'physics': target_physics,
                }

                save_error_statistics(frame_predictions, target_data_legacy, sample_dir)

            # Optional per-sample camera parameters from metadata.
            if metadata and 'camera_data' in metadata:
                camera_data = metadata['camera_data'][i]
                np.savez(sample_dir / "camera_params.npz",
                         **camera_data)

            # Optionally mirror the raw dataset files next to the results.
            if batch_data is not None and data_root is not None:
                save_original_data(sample_dir, i, batch_data, metadata, data_root, data_split)

            print("✅")

        except Exception as e:
            # Log and continue with the next sample instead of failing the batch.
            print(f"❌ 错误: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Batch-level statistics across all samples/frames.
    save_batch_statistics(predictions, targets, save_path)

    print(f"\n{'='*60}")
    print(f"✅ 保存完成!")
    print(f"📁 保存路径: {save_path}")
    print(f"{'='*60}\n")

    return save_path
|
|
|
|
|
|
|
|
def save_error_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
    """Compute and save prediction-error statistics (mean absolute error).

    Rebuilds flat per-object parameter arrays from the structured prediction
    frames, compares them against the legacy-format targets, and writes the
    resulting MAE values to ``save_dir / "error_statistics.json"``.

    Args:
        predictions: New-format frame list (each frame dict contains
            'frame_idx', 'superquadrics' and 'world_info').
        targets: Legacy-format target data ('objects' and 'world' arrays,
            indexed by frame).
        save_dir: Directory the JSON statistics file is written into.
    """
    stats = {}

    object_errors = []
    world_errors = []

    for frame in predictions:
        frame_idx = frame['frame_idx']

        # Rebuild the flat [N, 15] prediction array from the structured dicts.
        pred_objects = []
        for sq in frame['superquadrics']:
            obj_params = np.zeros(15, dtype=np.float32)
            obj_params[0] = 1.0 if sq['exists'] else 0.0
            obj_params[1:3] = sq['shape']
            obj_params[3:6] = sq['scale']
            obj_params[6:9] = sq['translation']
            obj_params[9:12] = sq['rotation']
            obj_params[12:15] = sq['velocity']
            pred_objects.append(obj_params)

        # BUGFIX: an empty superquadric list would make pred_obj 1-D and crash
        # on pred_obj.shape[1]; skip the object error for such frames.
        if pred_objects:
            pred_obj = np.array(pred_objects)

            # Targets may carry extra trailing params (e.g. inlier_ratio);
            # compare only the columns the predictions actually have.
            target_obj_full = targets['objects'][frame_idx]
            target_obj = target_obj_full[:, :pred_obj.shape[1]]

            # Only score objects that exist in the ground truth.
            exists_mask = target_obj[:, 0] > 0.5
            if exists_mask.any():
                error = np.mean(np.abs(pred_obj[exists_mask] - target_obj[exists_mask]))
                object_errors.append(error)

        # World error over the 8 shared params (position + quaternion + scale).
        world_info = frame['world_info']
        pred_world = np.concatenate([
            world_info['camera_position'],
            world_info['camera_quaternion'],
            [world_info['scene_scale']]
        ])

        target_world = targets['world'][frame_idx][:8]
        error = np.mean(np.abs(pred_world - target_world))
        world_errors.append(error)

    stats['object_mae'] = float(np.mean(object_errors)) if object_errors else 0.0
    # BUGFIX: guard like object_mae — np.mean([]) yields nan with a warning
    # when predictions is empty.
    stats['world_mae'] = float(np.mean(world_errors)) if world_errors else 0.0

    with open(save_dir / "error_statistics.json", 'w') as f:
        json.dump(stats, f, indent=2)
|
|
|
|
|
|
|
|
def save_batch_statistics(predictions: List[Dict], targets: Dict, save_dir: Path):
    """Write batch-level summary statistics to ``save_dir / "batch_statistics.json"``.

    Records batch size, frame count, a timestamp, and — when targets are
    available — the per-frame object-count distribution (mean/max/min of
    objects whose existence flag exceeds 0.5).
    """
    n_samples = predictions[0]['objects'].shape[0]
    n_frames = len(predictions)

    summary = {
        'batch_size': n_samples,
        'num_frames': n_frames,
        'timestamp': datetime.now().isoformat(),
    }

    if targets is not None:
        per_frame = []
        for frame_i in range(n_frames):
            # Count ground-truth objects present in each sample at this frame.
            counts = [
                int((targets['objects'][sample_i, frame_i, :, 0] > 0.5).sum())
                for sample_i in range(n_samples)
            ]
            per_frame.append({
                'frame': frame_i,
                'mean_objects': float(np.mean(counts)),
                'max_objects': int(max(counts)),
                'min_objects': int(min(counts)),
            })
        summary['objects_per_frame'] = per_frame

    with open(save_dir / "batch_statistics.json", 'w') as f:
        json.dump(summary, f, indent=2)
|
|
|
|
|
|
|
|
def save_visualization_script(save_dir: Path):
    """Write a self-contained Python script that visualizes the saved
    superquadric parameters, and mark it executable."""
    body = '''#!/usr/bin/env python3
"""
可视化生成的超二次元函数参数
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def superquadric_surface(a, b, c, e1, e2, n=50):
    """生成超二次元函数表面点"""
    eta = np.linspace(-np.pi/2, np.pi/2, n)
    omega = np.linspace(-np.pi, np.pi, n)
    eta, omega = np.meshgrid(eta, omega)

    x = a * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.cos(omega)) * np.abs(np.cos(omega))**e2
    y = b * np.sign(np.cos(eta)) * np.abs(np.cos(eta))**e1 * np.sign(np.sin(omega)) * np.abs(np.sin(omega))**e2
    z = c * np.sign(np.sin(eta)) * np.abs(np.sin(eta))**e1

    return x, y, z

# 加载预测数据
data = np.load('predictions.npz', allow_pickle=True)
frames = data['frames']

# 可视化第一帧
frame = frames[0]
objects = frame['objects']  # [max_objects, 12]

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')

# 绘制每个存在的物体
for obj_idx, obj_params in enumerate(objects):
    if obj_params[0] > 0.5:  # 物体存在
        # 提取参数
        shape_params = obj_params[1:3]  # e1, e2
        scale = obj_params[3:6]  # a, b, c
        translation = obj_params[6:9]
        rotation = obj_params[9:12]  # 简化处理,暂不应用旋转

        # 生成表面
        x, y, z = superquadric_surface(scale[0], scale[1], scale[2],
                                       shape_params[0], shape_params[1])

        # 应用平移
        x += translation[0]
        y += translation[1]
        z += translation[2]

        # 绘制
        ax.plot_surface(x, y, z, alpha=0.7, label=f'Object {obj_idx}')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Generated Superquadric Objects')
plt.savefig('visualization.png')
plt.show()
'''

    target = save_dir / "visualize.py"
    target.write_text(body)

    # Make the helper runnable directly (rwxr-xr-x).
    target.chmod(0o755)
|
|
|
|
|
|
|
|
def save_original_data(sample_dir: Path, sample_idx: int, batch_data: Dict, metadata: Dict, data_root: str, data_split: str = "validation"):
    """
    Save the original data files: RGB images, depth maps, segmentation maps,
    point clouds, etc.

    Copies the raw dataset sequence backing this sample into
    ``sample_dir / "original_data"`` and writes a file manifest. If the
    sequence name is unavailable or the raw files cannot be located, a marker
    text file explaining the situation is written instead. All errors are
    caught and logged — this is a best-effort operation that must not crash
    the save pipeline.

    Args:
        sample_dir: Save directory of the current sample.
        sample_idx: Index of the sample within the batch.
        batch_data: Batch data (may carry 'point_clouds' / 'scene_normalization').
        metadata: Metadata (expected to carry 'sequence_names').
        data_root: Root directory of the MOVi dataset.
        data_split: Dataset split ('train' or 'validation').
    """
    try:
        # Resolve the raw-sequence name from metadata, if the loader provided it.
        sequence_name = None
        if metadata and 'sequence_names' in metadata and metadata['sequence_names'] is not None:
            sequence_name = metadata['sequence_names'][sample_idx]

        if not sequence_name:
            print(f"Warning: No sequence name found for sample {sample_idx}")

            # Leave a marker file documenting why nothing was copied.
            no_original_data_file = sample_dir / "no_original_data.txt"
            with open(no_original_data_file, 'w') as f:
                f.write("原始数据未保存,因为无法获取序列名称。\n")
                f.write("这可能是因为数据加载器没有返回sequence_names字段。\n")
                f.write(f"Sample index: {sample_idx}\n")
                f.write(f"Generated at step: {datetime.now().isoformat()}\n")
            return

        data_root_path = Path(data_root)

        # Raw layout: <data_root>/<split>/<sequence_name>/...
        original_sample_dir = data_root_path / data_split / sequence_name

        if not original_sample_dir.exists():
            print(f"Warning: Could not find original data for {sequence_name} in {data_split} split")

            # Marker file recording the path that was searched.
            error_file = sample_dir / "original_data_not_found.txt"
            with open(error_file, 'w') as f:
                f.write(f"原始数据未找到\n")
                f.write(f"查找路径: {original_sample_dir}\n")
                f.write(f"数据集: {data_split}\n")
                f.write(f"序列名: {sequence_name}\n")
                f.write(f"时间: {datetime.now().isoformat()}\n")
            return

        print(f"Copying original data from {original_sample_dir}")

        original_data_dir = sample_dir / "original_data"
        original_data_dir.mkdir(exist_ok=True)

        # RGB frames: copied file-by-file (frame_*.png).
        rgb_dir = original_data_dir / "rgb"
        rgb_dir.mkdir(exist_ok=True)
        original_rgb_dir = original_sample_dir / "rgb"
        if original_rgb_dir.exists():
            for rgb_file in sorted(original_rgb_dir.glob("frame_*.png")):
                shutil.copy2(rgb_file, rgb_dir / rgb_file.name)

        # Depth: prefer the merged .npz, else fall back to per-frame .npy files.
        depth_dir = original_data_dir / "depth"
        depth_dir.mkdir(exist_ok=True)
        original_depth_dir = original_sample_dir / "depth"
        if original_depth_dir.exists():

            merged_depth = original_depth_dir / "depth_merge.npz"
            if merged_depth.exists():
                shutil.copy2(merged_depth, depth_dir / "depth_merge.npz")
            else:

                for depth_file in sorted(original_depth_dir.glob("frame_*.npy")):
                    shutil.copy2(depth_file, depth_dir / depth_file.name)

        # Segmentation: same merged-vs-per-frame scheme as depth.
        seg_dir = original_data_dir / "segmentation"
        seg_dir.mkdir(exist_ok=True)
        original_seg_dir = original_sample_dir / "segmentation"
        if original_seg_dir.exists():

            merged_seg = original_seg_dir / "segmentation_merge.npz"
            if merged_seg.exists():
                shutil.copy2(merged_seg, seg_dir / "segmentation_merge.npz")
            else:

                for seg_file in sorted(original_seg_dir.glob("frame_*.npy")):
                    shutil.copy2(seg_file, seg_dir / seg_file.name)

        # Normals: same merged-vs-per-frame scheme.
        normal_dir = original_data_dir / "normal"
        normal_dir.mkdir(exist_ok=True)
        original_normal_dir = original_sample_dir / "normal"
        if original_normal_dir.exists():

            merged_normal = original_normal_dir / "normal_merge.npz"
            if merged_normal.exists():
                shutil.copy2(merged_normal, normal_dir / "normal_merge.npz")
            else:

                for normal_file in sorted(original_normal_dir.glob("frame_*.npy")):
                    shutil.copy2(normal_file, normal_dir / normal_file.name)

        # Single-file extras, copied only when present.
        camera_traj_file = original_sample_dir / "camera_trajectory.npz"
        if camera_traj_file.exists():
            shutil.copy2(camera_traj_file, original_data_dir / "camera_trajectory.npz")

        metadata_file = original_sample_dir / "metadata.json"
        if metadata_file.exists():
            shutil.copy2(metadata_file, original_data_dir / "metadata.json")

        full_cache_file = original_sample_dir / "Full_Sample_Data_for_Learning_Target.npz"
        if full_cache_file.exists():
            shutil.copy2(full_cache_file, original_data_dir / "Full_Sample_Data_for_Learning_Target.npz")

        # Object coordinates / point clouds share the merged-vs-per-frame scheme.
        for folder_name in ['object_coordinates', 'point_clouds']:
            folder_dir = original_data_dir / folder_name
            original_folder_dir = original_sample_dir / folder_name
            if original_folder_dir.exists():
                folder_dir.mkdir(exist_ok=True)

                merged_file = original_folder_dir / f"{folder_name}_merge.npz"
                if merged_file.exists():
                    shutil.copy2(merged_file, folder_dir / f"{folder_name}_merge.npz")
                else:

                    for npy_file in sorted(original_folder_dir.glob("frame_*.npy")):
                        shutil.copy2(npy_file, folder_dir / npy_file.name)

        # Extras attached by the dataloader, if any.
        # NOTE(review): assumes batch_data['point_clouds'][sample_idx] is a
        # mapping usable as np.savez kwargs — confirm against the dataloader.
        if 'point_clouds' in batch_data:
            pc_data = batch_data['point_clouds'][sample_idx]
            np.savez_compressed(original_data_dir / "point_clouds.npz", **pc_data)

        # Scene normalization parameters, serialized as JSON with defaults of
        # 1.0 when 'scale'/'extent' are absent.
        if 'scene_normalization' in batch_data:
            norm_params = batch_data['scene_normalization'][sample_idx]
            with open(original_data_dir / "scene_normalization.json", 'w') as f:
                json.dump({
                    'scene_center': norm_params['center'].tolist() if hasattr(norm_params['center'], 'tolist') else norm_params['center'],
                    'scene_scale': float(norm_params['scale']) if 'scale' in norm_params else 1.0,
                    'scene_extent': float(norm_params['extent']) if 'extent' in norm_params else 1.0
                }, f, indent=2)

        # Manifest listing everything that was copied (excluding itself).
        with open(original_data_dir / "file_manifest.txt", 'w') as f:
            f.write(f"Original sequence: {sequence_name}\n")
            f.write(f"Data split: {data_split}\n")
            f.write(f"Original path: {original_sample_dir}\n")
            f.write(f"Copied at: {datetime.now().isoformat()}\n\n")
            f.write("Files included:\n")
            for item in sorted(original_data_dir.rglob("*")):
                if item.is_file() and item.name != "file_manifest.txt":
                    f.write(f"- {item.relative_to(original_data_dir)}\n")

    except Exception as e:
        # Best-effort: log the failure and return without raising.
        print(f"Error saving original data for sample {sample_idx}: {e}")
        import traceback
        traceback.print_exc()
|
|