---
# NOTE(review): removed Hugging Face web-page residue that was not valid YAML
# (original lines: "t1an's picture" / "Upload folder using huggingface_hub" /
# "f17ae24 verified" — upload metadata, kept here only as a comment).
# Configuration for Language Table World Model Training with Diffusion Forcing
# Dynamics Model Class
dynamics_class: "DiffusionForcing_WM"
# Model identifier for DIT_CLASS_MAP
model_name: "VideoDiT"
# Configuration passed to the DiT model constructor
# NOTE(review): indentation appears to have been stripped from this file — as
# written, `model_config:` parses as null and its fields become top-level keys.
# Nesting below is reconstructed from the section comments; confirm against the
# consumer's expected schema.
model_config:
  in_channels: 16  # Latent channels from Wan VAE
  patch_size: 2
  dim: 1024  # Hidden dimension
  num_layers: 16
  num_heads: 16
  action_dim: 2  # Language Table action dimension
  action_compress_rate: 4  # Compresses action sequence (1+4k) to latent sequence (1+k)
  max_frames: 33  # Max frames (T=33 -> 1+4*8)
  action_dropout_prob: 0.0  # close CFG for now
  temporal_causal: true  # Diffusion Forcing REQUIRES causal attention
vae_name: "WanVAE"
vae_config:
  - "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
scheduler: "FlowMatch" # Will be instantiated in dynamics class
training_timesteps: 1000
# Dataset Configuration
dataset:
  name: "language_table"
  train_seq_len: 9  # Train on shorter sequences (T=9 -> 1+4*2)
  eval_seq_len: 17  # Evaluate on longer sequences (T=17 -> 1+4*4)
  train_test_split: 50  # 50:1 split
# Training Hyperparameters
training:
  batch_size: 8  # Batch size per GPU
  # NOTE(review): was `1e-4`, which PyYAML resolves as the STRING "1e-4"
  # (YAML 1.1 floats require a decimal point before the exponent);
  # `1.0e-4` resolves as the float 0.0001.
  learning_rate: 1.0e-4
  num_epochs: 2000
  grad_clip: 1.0
  checkpoint_freq: 2000  # Numbered checkpoints for eval
  latest_freq: 500  # Only updates latest.pt for resuming
  val_freq: 500  # Video Logging (Less frequent)
  eval_freq: 100  # MSE Rollout (More frequent)
  log_freq: 5  # Steps
  num_workers: 4
  gen_mode: "autoregressive"  # Mode for evaluation generation: "autoregressive" or "parallel"
  inference_steps: 50  # Number of denoising steps for generation
# Distributed Training
distributed:
  use_ddp: true
  use_fsdp: false  # Toggle FSDP
# WandB Configuration
wandb:
  project: "world_model"
  entity: null  # Set if needed
  # SECURITY(review): hardcoded API key committed to a config file — rotate this
  # key and supply it via the WANDB_API_KEY environment variable or a secret
  # store instead of storing it in version control.
  api_key: "62da90010e5c8cc94a66361396c57cea8c2c1e21"
  run_name: "lang_table_df_v1"