FangSen9000's picture
Upload nano_WaveGen
8e263cf verified
# configs/default.yaml
# ==================== 核心配置 ====================
# 训练配置
training:
batch_size: 24 # Reduced for faster first step (8 per GPU)
# 基于步数的训练(v30新增)
max_steps: 50000 # 总训练步数(可以根据需要调整)
# 时序预测配置
max_history_frames: 3 # 最大历史帧数(1-3帧)
bidirectional_training: true # 双向预测训练(从中间帧开始)
# max_prediction_frames: 1 # <deprecated> 由模型覆盖整段序列,这里保留用于兼容
initial_teacher_forcing_prob: 0.5 # 初始锚点帧教师强制概率
initial_frame:
strategy: "middle" # 可选: middle | fixed | random
offset: 0 # 在基础策略上的偏移
random_history_sampling: true # 是否在推理时随机选择0-历史帧数量作为参考帧
freeze_static_from_anchor: true # 是否在预测序列时固定0-5维的静态参数沿用锚点帧
multi_sample_attempts: 5 # 每个样本的随机尝试次数(仅训练时启用)
decoder_noise_std: 0.2 # 解码阶段加入的高斯噪声标准差(0表示禁用)
frame_rate: 8.0 # MOVi 样本在预处理阶段统一为 8fps,速度计算与训练假设保持一致
# 调试和日志配置
debug_print_interval: 1 # 每多少步打印一次调试信息(损失值等)
log_interval: 50 # 每多少步记录一次日志
learning_rate: 0.001 # Standard T5 learning rate
gradient_clip_val: 1.0 # Gradient clipping value
# GPU配置
use_multi_gpu: true # 启用多GPU训练
gpu_list: [0, 1, 2, 3, 4, 5, 6, 7] # 可用GPU列表
use_free_gpus: true # 自动选择空闲GPU
# 保存训练的模型和生成结果 配置
evaluation:
max_batches: 0 # 例如0禁用验证,只比较训练 loss。改成 None 就是验证全量;任何正整数则限制评估批次数。
save_generation:
enabled: true #保存训练的模型
save_gt: true # 是否保存GT数据
fixed_samples: 5 # 固定样本数量,用于对比
save_interval: 100 # 每100步保存一次
save_dir: "core_space" # 保存目录
# ==================== Text2Wave 配置 ====================
# 模型设置
text2wave_model:
# 原始模型: google/long-t5-tglobal-base
model_name: "google/t5-v1_1-small"
# 损失函数配置
loss:
# 损失权重
weights:
wave_loss: 4.0 # 波损失(超二次元参数)权重
wave_contrastive_loss: 2.0 # 序列级对比损失权重
world_info_loss: 0.5 # 世界信息损失(相机,缩放,时间)权重
controllable_info_loss: 0.1 # 可控制信息损失(质量,摩擦,弹性)权重
pla_loss: 3.0 # 最小作用量约束损失权重
wave_contrastive:
temperature: 0.2 # 对比分布温度
# 数据配置
data:
# MOVi数据集配置
num_workers: 32 # 数据加载线程数
max_sequences: 100 # 最大序列数,-1表示使用所有数据,设置较小值用于快速测试
physics:
gravity: 9.81 # 自由落体重力加速度(单位:m/s^2)
collision_buffer: 1.05 # 判定碰撞时的半径放大系数
# ==================== Wave2Pixel 配置 ====================
# 网格配置
grid:
size: 64 # 3D网格分辨率
prob_threshold: 0.5
# 世界坐标系配置
world_coordinate_system:
enabled: true # 是否启用世界坐标系
world_scale: 10.0 # 世界坐标范围 ±10米
voxel_size: 0.05 # 体素大小 5cm
near_plane: 0.1 # 近平面距离
far_plane: 50.0 # 远平面距离
predict_world_scale: true # 让模型预测世界缩放比例
world_scale_loss_weight: 0.1 # 世界缩放比例损失权重 - 增加到1.0以加快学习
# 相机配置
camera:
default_view: "front"
fov: 60
near: 0.1
far: 100.0
# 世界坐标系中的相机位置
views:
front: [0, 0, 2]
back: [0, 0, -2]
left: [-2, 0, 0]
right: [2, 0, 0]
top: [0, 2, 0]
bottom: [0, -2, 0]
# 相机旋转角度 (pitch, yaw, roll)
view_rotations:
front: [0, 0, 0]
back: [0, 3.14159, 0] # 180度旋转
left: [0, -1.5708, 0] # -90度旋转
right: [0, 1.5708, 0] # 90度旋转
top: [-1.5708, 0, 0] # -90度俯视
bottom: [1.5708, 0, 0] # 90度仰视
# 生成配置
generation:
mode: "image" # "image" 或 "video"
time:
start: 0.0
end: 12.0
fps: 30 # 外部观察频率
timestep: 0.0 # 用于单帧图像生成
compute_wsf: false # 是否默认计算完整WSF场
output_dir: "core_space" # 默认输出目录
# 输出格式配置
output:
format: "triple_channel" # 可选: "complex", "dual_channel", "triple_channel"
third_channel: "amplitude" # 如果format为"triple_channel",第三通道的内容: "amplitude", "phase", "none"
# Wave2Pixel相关的模型组件
model:
wave_encoder:
hidden_dim: 256
dropout: 0.1
feature_extractor:
input_dim: 4 # 实部、虚部、振幅、相位
hidden_dim: 64
output_dim: 32
dropout: 0.1
# 重命名为pixel_net以匹配代码中的使用
pixel_net:
channels: [32, 64, 128, 64, 4] # 最后4通道: RGB + 概率
kernel_size: 3
padding: 1
dropout: 0.1