# TraceGen / train.yaml
project: "siglip_trajectory_diffusion"
seed: 1337
num_kps: 400 # For dataset compatibility
d_model: 768
siglip_ckpt: "google/siglip-base-patch16-384"
freeze_vision: true
freeze_t5: true
t5_model: "t5-base"
absolute_action: false # alternative: true
dinov2_model: vit_large_patch16_dinov3.lvd1689m # alternative: vit_base_patch16_dinov3.lvd1689m
trajectory_horizon: 32 # Number of frames to generate (future frames only)
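
# Quick consistency note (an inference from the values above, not stated in
# this file): siglip-base-patch16-384 produces (384 / 16)^2 = 576 patch tokens
# with hidden size 768, and t5-base also has hidden size 768, so both match
# d_model: 768 above and text_embed_dim: 768 in the decoder section below.
# The dinov2_model value follows timm's model-naming convention
# (DINOv3 ViT-L/16 with the lvd1689m pretrained tag).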
# Model architecture
model:
  vision_encoder:
    patch_size: 16
    image_size: 384
  text_encoder:
    max_length: 64
  decoder:
    latent_dim: 768 # Trunk output dimension
    num_attention_heads: 8 # CogVideoX attention heads
    attention_head_dim: 64 # CogVideoX attention head dimension
    put_frames_in_channels: 2 # Fold groups of n frames into channels instead of the temporal dimension (B T C H W -> B T/n C*n H W); here n = 2
    in_channels: 3 # Number of input channels in latent space (CogVideoX)
    out_channels: 3 # Number of output channels in latent space (CogVideoX)
    num_layers: 4 # Number of CogVideoX transformer layers
    num_frames: ${trajectory_horizon} # Number of frames to generate (future frames only)
    frame_size: 20 # Spatial size of latent frames (20x20)
    patch_size: 2 # Patch size for latents
    patch_size_t: 1 # Patch size for the temporal dimension
    max_text_seq_length: 704 # Maximum text sequence length for CogVideoX
    text_embed_dim: 768 # Text embedding dimension for CogVideoX
    use_rotary_positional_embeddings: false # Use rotary embeddings
    scale_factor: 0.7 # CogVideoX-specific scaling factor
    scale_factor_spatial: 1 # Spatial scaling factor
    scale_factor_temporal: 1 # Temporal scaling factor
    enable_encoder_hidden_states_grad: true # Enable gradient flow through the conditioning
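
# Worked size check, derived from the decoder values above (a sketch, since
# the exact patchify order depends on the training code): frame_size = 20
# with patch_size = 2 gives (20 / 2)^2 = 100 spatial patches per latent
# frame, and num_frames = 32 with patch_size_t = 1 gives 32 temporal
# positions, i.e. 3,200 video tokens; if put_frames_in_channels first folds
# pairs of frames into channels, that drops to 16 * 100 = 1,600 tokens.
# Text conditioning adds up to max_text_seq_length = 704 tokens.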
# device: "cuda:1"
# Training configuration
train:
  epochs: 700
  batch_size: 32
  lr_decoder: 2.0e-4
  lr_backbone: 2.0e-5
  weight_decay: 0.05
  warmup_steps: 100
  clip_grad_norm: 1.0
  save_every: 1
  num_log_steps_per_epoch: 0
  eval_every: 1
  visualize_every: 1
  visualize_during_validation: true
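
# Note on the two learning rates (inferred from the key names, not stated in
# this file): the decoder presumably trains at lr_decoder = 2.0e-4 while any
# unfrozen backbone parameters use the 10x smaller lr_backbone = 2.0e-5,
# i.e. two optimizer parameter groups sharing weight_decay = 0.05 and the
# 100-step warmup.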
# Data configuration
# NOTE: dataset_dirs should be overridden in train.local.yaml
data:
  dataset_dirs: [] # Override this in .local.yaml with your machine-specific paths (see the example at the end of this file)
  cache_dir: "./dataset_cache" # Where to store cache files
  val_split: 0.01 # 1% of episodes held out for validation
  random_seed: 42 # For a reproducible train/val split
  num_workers: 16
  pin_memory: false
  augmentation: {} # Empty; add augmentation settings here
# Logging configuration
# NOTE: checkpoint_dir should be overridden in train.local.yaml
logging:
  wandb_project: "tracegen"
  use_wandb: false # Set to true to enable wandb logging
  log_every: 100
  save_dir: "./checkpoints"
  checkpoint_dir: "./checkpoint/" # Override this in .local.yaml with your machine-specific path
# Hardware configuration
hardware:
  device: "cuda"
  mixed_precision: true
  compile_model: true

# NOTE: test_path should be overridden in train.local.yaml
test_path: null # Override this in .local.yaml with your machine-specific path (see the example below)
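
# Example train.local.yaml (illustration only; every path below is a
# hypothetical placeholder, not part of the repository):
#
# data:
#   dataset_dirs:
#     - /data/tracegen/episodes
# logging:
#   checkpoint_dir: /data/tracegen/checkpoints/
# test_path: /data/tracegen/test_episodes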