---
# Configuration for Language Table World Model Training with Diffusion Forcing

# Dynamics Model Class
dynamics_class: "DiffusionForcing_WM"

# Model identifier for DIT_CLASS_MAP
model_name: "VideoDiT"

# Configuration passed to the DiT model constructor
model_config:
  in_channels: 16  # Latent channels from Wan VAE
  patch_size: 2
  dim: 1024  # Hidden dimension
  num_layers: 16
  num_heads: 16
  action_dim: 2  # Language Table action dimension
  action_compress_rate: 4  # Compresses action sequence (1+4k) to latent sequence (1+k)
  max_frames: 33  # Max frames (T=33 -> 1+4*8)
  action_dropout_prob: 0.0  # Disables classifier-free guidance for now
  temporal_causal: true  # Diffusion Forcing REQUIRES causal attention

vae_name: "WanVAE"
vae_config:
  - "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"

scheduler: "FlowMatch"  # Will be instantiated in dynamics class
training_timesteps: 1000

# Dataset Configuration
dataset:
  name: "language_table"
  train_seq_len: 9  # Train on shorter sequences (T=9 -> 1+4*2)
  eval_seq_len: 17  # Evaluate on longer sequences (T=17 -> 1+4*4)
  train_test_split: 50  # 50:1 split

# Training Hyperparameters
training:
  batch_size: 8  # Batch size per GPU
  # Written with an explicit decimal point: YAML 1.1 resolvers (e.g. PyYAML)
  # parse bare "1e-4" as the STRING "1e-4", not a float.
  learning_rate: 1.0e-4
  num_epochs: 2000
  grad_clip: 1.0
  checkpoint_freq: 2000  # Numbered checkpoints for eval
  latest_freq: 500  # Only updates latest.pt for resuming
  val_freq: 500  # Video Logging (Less frequent)
  eval_freq: 100  # MSE Rollout (More frequent)
  log_freq: 5  # Steps
  num_workers: 4
  gen_mode: "autoregressive"  # Mode for evaluation generation: "autoregressive" or "parallel"
  inference_steps: 50  # Number of denoising steps for generation

# Distributed Training
distributed:
  use_ddp: true
  use_fsdp: false  # Toggle FSDP

# WandB Configuration
wandb:
  project: "world_model"
  entity: null  # Set if needed
  # SECURITY NOTE(review): plaintext API key committed to the repo — rotate this
  # key and load it from an environment variable (e.g. WANDB_API_KEY) instead.
  api_key: "62da90010e5c8cc94a66361396c57cea8c2c1e21"
  run_name: "lang_table_df_v1"