---
# Configuration for Language Table world-model training with Diffusion Forcing.

# Dynamics model class
dynamics_class: "DiffusionForcing_WM"
# Model identifier looked up in DIT_CLASS_MAP
model_name: "VideoDiT"

# Configuration passed to the DiT model constructor
model_config:
  in_channels: 16  # latent channels from the Wan VAE
  patch_size: 2
  dim: 1024  # hidden dimension
  num_layers: 16
  num_heads: 16
  action_dim: 2  # Language Table action dimension
  # Compresses an action sequence of length 1+4k to a latent sequence of length 1+k
  action_compress_rate: 4
  max_frames: 33  # max frames (T=33 -> 1+4*8)
  action_dropout_prob: 0.0  # no action dropout, i.e. CFG disabled for now
  temporal_causal: true  # Diffusion Forcing REQUIRES causal attention

# VAE selection; vae_config holds the Wan2.1 VAE checkpoint path.
# NOTE(review): vae_config is a list — presumably positional args for the VAE
# constructor; TODO confirm against the loader.
vae_name: "WanVAE"
# NOTE(review): machine-specific absolute path; override per environment.
vae_config:
  - "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
scheduler: "FlowMatch"  # will be instantiated in the dynamics class
training_timesteps: 1000  # presumably the scheduler's training timestep count

# Dataset configuration
dataset:
  name: "language_table"
  train_seq_len: 9  # train on shorter sequences (T=9 -> 1+4*2)
  eval_seq_len: 17  # evaluate on longer sequences (T=17 -> 1+4*4)
  train_test_split: 50  # 50:1 train:test split

# Training hyperparameters
training:
  batch_size: 8  # batch size per GPU
  # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. PyYAML)
  # only resolve floats that contain a '.', so a bare `1e-4` loads as the
  # string "1e-4". `1.0e-4` is a float under both YAML 1.1 and 1.2.
  learning_rate: 1.0e-4
  num_epochs: 2000
  grad_clip: 1.0
  checkpoint_freq: 2000  # numbered checkpoints, used for eval
  latest_freq: 500  # only updates latest.pt, for resuming
  val_freq: 500  # video logging (less frequent)
  eval_freq: 100  # MSE rollout (more frequent)
  log_freq: 5  # in steps
  num_workers: 4
  # Evaluation generation mode: "autoregressive" or "parallel"
  gen_mode: "autoregressive"
  inference_steps: 50  # number of denoising steps for generation

# Distributed training
distributed:
  use_ddp: true
  use_fsdp: false  # toggle FSDP

# Weights & Biases configuration
wandb:
  project: "world_model"
  entity: null  # set if needed
  # Never commit API keys to version control. The key previously hard-coded
  # here is exposed and must be revoked/rotated in the W&B account settings.
  # NOTE(review): assumes the loader passes this value to
  # wandb.login(key=...), which with key=None falls back to the
  # WANDB_API_KEY env var / `wandb login` credentials — TODO confirm.
  api_key: null
  run_name: "lang_table_df_v1"