# configs/default.yaml # ==================== 核心配置 ==================== # 训练配置 training: batch_size: 24 # Reduced for faster first step (8 per GPU) # 基于步数的训练(v30新增) max_steps: 50000 # 总训练步数(可以根据需要调整) # 时序预测配置 max_history_frames: 3 # 最大历史帧数(1-3帧) bidirectional_training: true # 双向预测训练(从中间帧开始) # max_prediction_frames: 1 # 由模型覆盖整段序列,这里保留用于兼容 initial_teacher_forcing_prob: 0.5 # 初始锚点帧教师强制概率 initial_frame: strategy: "middle" # 可选: middle | fixed | random offset: 0 # 在基础策略上的偏移 random_history_sampling: true # 是否在推理时随机选择0-历史帧数量作为参考帧 freeze_static_from_anchor: true # 是否在预测序列时固定0-5维的静态参数沿用锚点帧 multi_sample_attempts: 5 # 每个样本的随机尝试次数(仅训练时启用) decoder_noise_std: 0.2 # 解码阶段加入的高斯噪声标准差(0表示禁用) frame_rate: 8.0 # MOVi 样本在预处理阶段统一为 8fps,速度计算与训练假设保持一致 # 调试和日志配置 debug_print_interval: 1 # 每多少步打印一次调试信息(损失值等) log_interval: 50 # 每多少步记录一次日志 learning_rate: 0.001 # Standard T5 learning rate gradient_clip_val: 1.0 # Gradient clipping value # GPU配置 use_multi_gpu: true # 启用多GPU训练 gpu_list: [0, 1, 2, 3, 4, 5, 6, 7] # 可用GPU列表 use_free_gpus: true # 自动选择空闲GPU # 保存训练的模型和生成结果 配置 evaluation: max_batches: 0 # 例如0禁用验证,只比较训练 loss。改成 None 就是验证全量;任何正整数则限制评估批次数。 save_generation: enabled: true #保存训练的模型 save_gt: true # 是否保存GT数据 fixed_samples: 5 # 固定样本数量,用于对比 save_interval: 100 # 每100步保存一次 save_dir: "core_space" # 保存目录 # ==================== Text2Wave 配置 ==================== # 模型设置 text2wave_model: # 原始模型: google/long-t5-tglobal-base model_name: "google/t5-v1_1-small" # 损失函数配置 loss: # 损失权重 weights: wave_loss: 4.0 # 波损失(超二次元参数)权重 wave_contrastive_loss: 2.0 # 序列级对比损失权重 world_info_loss: 0.5 # 世界信息损失(相机,缩放,时间)权重 controllable_info_loss: 0.1 # 可控制信息损失(质量,摩擦,弹性)权重 pla_loss: 3.0 # 最小作用量约束损失权重 wave_contrastive: temperature: 0.2 # 对比分布温度 # 数据配置 data: # MOVi数据集配置 num_workers: 32 # 数据加载线程数 max_sequences: 100 # 最大序列数,-1表示使用所有数据,设置较小值用于快速测试 physics: gravity: 9.81 # 自由落体重力加速度(单位:m/s^2) collision_buffer: 1.05 # 判定碰撞时的半径放大系数 # ==================== Wave2Pixel 配置 ==================== # 网格配置 grid: size: 64 # 3D网格分辨率 prob_threshold: 0.5 # 世界坐标系配置 world_coordinate_system: enabled: true # 是否启用世界坐标系 world_scale: 10.0 # 世界坐标范围 ±10米 voxel_size: 0.05 # 体素大小 5cm near_plane: 0.1 # 近平面距离 far_plane: 50.0 # 远平面距离 predict_world_scale: true # 让模型预测世界缩放比例 world_scale_loss_weight: 0.1 # 世界缩放比例损失权重 - 增加到1.0以加快学习 # 相机配置 camera: default_view: "front" fov: 60 near: 0.1 far: 100.0 # 世界坐标系中的相机位置 views: front: [0, 0, 2] back: [0, 0, -2] left: [-2, 0, 0] right: [2, 0, 0] top: [0, 2, 0] bottom: [0, -2, 0] # 相机旋转角度 (pitch, yaw, roll) view_rotations: front: [0, 0, 0] back: [0, 3.14159, 0] # 180度旋转 left: [0, -1.5708, 0] # -90度旋转 right: [0, 1.5708, 0] # 90度旋转 top: [-1.5708, 0, 0] # -90度俯视 bottom: [1.5708, 0, 0] # 90度仰视 # 生成配置 generation: mode: "image" # "image" 或 "video" time: start: 0.0 end: 12.0 fps: 30 # 外部观察频率 timestep: 0.0 # 用于单帧图像生成 compute_wsf: false # 是否默认计算完整WSF场 output_dir: "core_space" # 默认输出目录 # 输出格式配置 output: format: "triple_channel" # 可选: "complex", "dual_channel", "triple_channel" third_channel: "amplitude" # 如果format为"triple_channel",第三通道的内容: "amplitude", "phase", "none" # Wave2Pixel相关的模型组件 model: wave_encoder: hidden_dim: 256 dropout: 0.1 feature_extractor: input_dim: 4 # 实部、虚部、振幅、相位 hidden_dim: 64 output_dim: 32 dropout: 0.1 # 重命名为pixel_net以匹配代码中的使用 pixel_net: channels: [32, 64, 128, 64, 4] # 最后4通道: RGB + 概率 kernel_size: 3 padding: 1 dropout: 0.1