File size: 2,136 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
---
# Configuration for Language Table World Model Training with Diffusion Forcing

# Dynamics Model Class
dynamics_class: "DiffusionForcing_WM"

# Model identifier for DIT_CLASS_MAP
model_name: "VideoDiT"

# Configuration passed to the DiT model constructor
model_config:
  in_channels: 16            # Latent channels from Wan VAE
  patch_size: 2
  dim: 1024                  # Hidden dimension
  num_layers: 16
  num_heads: 16
  action_dim: 2              # Language Table action dimension
  action_compress_rate: 4    # Compresses action sequence (1+4k) to latent sequence (1+k)
  max_frames: 33             # Max frames (T=33 -> 1+4*8)
  action_dropout_prob: 0.0   # 0.0 disables classifier-free guidance (CFG) for now
  temporal_causal: true      # Diffusion Forcing REQUIRES causal attention
  vae_name: "WanVAE"
  vae_config:
    - "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
  scheduler: "FlowMatch"     # Will be instantiated in dynamics class
  training_timesteps: 1000

# Dataset Configuration
dataset:
  name: "language_table"
  train_seq_len: 9           # Train on shorter sequences (T=9 -> 1+4*2)
  eval_seq_len: 17           # Evaluate on longer sequences (T=17 -> 1+4*4)
  train_test_split: 50       # 50:1 split

# Training Hyperparameters
training:
  batch_size: 8             # Batch size per GPU
  learning_rate: 1.0e-4      # NOTE: must keep the mantissa dot — "1e-4" loads as a STRING under YAML 1.1 parsers (e.g. PyYAML)
  num_epochs: 2000
  grad_clip: 1.0
  checkpoint_freq: 2000      # Numbered checkpoints for eval
  latest_freq: 500           # Only updates latest.pt for resuming
  val_freq: 500              # Video Logging (Less frequent)
  eval_freq: 100             # MSE Rollout (More frequent)
  log_freq: 5                # Steps
  num_workers: 4
  gen_mode: "autoregressive" # Mode for evaluation generation: "autoregressive" or "parallel"
  inference_steps: 50        # Number of denoising steps for generation

# Distributed Training
distributed:
  use_ddp: true
  use_fsdp: false            # Toggle FSDP

# WandB Configuration
wandb:
  project: "world_model"
  entity: null               # Set if needed
  api_key: null              # SECURITY: never commit API keys — set the WANDB_API_KEY environment variable instead (the previously committed key must be revoked/rotated)
  run_name: "lang_table_df_v1"