# traj-diffusion/config.yaml
# Uploaded by wezteoh in commit fac6f28 (verified), message: "upload models".
---
# Dataset / datamodule configuration.
# NOTE(review): nesting restored from a copy whose indentation had been
# stripped. Confirm the placement of val_mask_path (under masking vs. params),
# num_workers and max_val_samples against the code that reads this block.
data:
  name: trajectory_nba_filling
  full_seq_len: 30
  delta_len: 30
  num_agents: 11
  # Presumably (x, y) per agent per frame — confirm.
  coord_dim: 2
  # Presumably NBA court dimensions in feet (94 x 50) — confirm units.
  court_width: 94.0
  court_height: 50.0
  params:
    train_batch_size: 128
    val_batch_size: 512
    # The next five values repeat those in the parent `data` mapping;
    # keep both copies in sync by hand.
    full_seq_len: 30
    num_agents: 11
    coord_dim: 2
    court_width: 94.0
    court_height: 50.0
    train_path: data/nba_train.npy
    # Validation is run on the test-split file.
    val_path: data/nba_test.npy
    # Presumably the field names in each sample/batch dict; the model section
    # of this file uses the same names.
    trajectory_key: trajectory
    context_key: context
    mask_key: obs_mask
    position_0_key: position_0
    # Per-coordinate fill value for context entries; presumably marks
    # unobserved positions (outside the court range) — confirm.
    context_fill: [-4.0, -4.0]
    include_delta_in_context: true
    delta_context_fill: [0.0, 0.0]
    # Identity normalisation for deltas (shift 0, scale 1); presumably applied
    # as (delta - shift) / scale — confirm.
    delta_shift: [0.0, 0.0]
    delta_scale: [1.0, 1.0]
    masking:
      train:
        # One mask strategy per sample, chosen by weight (0.35 + 0.35 + 0.3
        # = 1.0), presumably.
        mixture:
          - weight: 0.35
            even:
              mode: random
              prefix_min: 1
              prefix_max: 20
          - weight: 0.35
            agent:
              mode: random
              n_masked_min: 5
              n_masked_max: 11
          - weight: 0.3
            # Union of the `even` and `agent` masks configured in a / b.
            hybrid:
              combine: union
              a:
                even:
                  mode: random
                  prefix_min: 1
                  prefix_max: 20
              b:
                agent:
                  mode: random
                  n_masked_min: 5
                  n_masked_max: 11
      # Presumably precomputed validation masks. NOTE(review): absolute,
      # machine-specific path, unlike train_path / val_path.
      val_mask_path: /root/code/gameplay-trajectory-diffusion/data/nba_test_mask_v1.npy
    num_workers: 0
    max_val_samples: 5120
# Diffusion model configuration.
# NOTE(review): nesting restored from an indentation-stripped copy; confirm
# that `backbone` is read as a child of `model`.
model:
  name: trajectory_filling_ddpm
  # Same field names as the data params in this file.
  trajectory_key: trajectory
  context_key: context
  mask_key: obs_mask
  position_0_key: position_0
  # Presumably classifier-free guidance: guidance weight, and probability of
  # dropping the conditioning during training — confirm.
  guidance_scale: 2.0
  p_uncond: 0.1
  log_blend_trajectory_video: false
  # Noise schedule: 1000 steps, linear schedule from 0.0001 to 0.02
  # (presumably beta values).
  timesteps: 1000
  beta_schedule: linear
  linear_start: 0.0001
  linear_end: 0.02
  # Presumably only read when beta_schedule is cosine.
  cosine_s: 0.008
  # Presumably the network predicts the added noise (epsilon).
  parameterization: eps
  loss_type: l2
  diffusion_loss_type: rescaled_mse
  log_diagnostic_vb: true
  vb_decoder_nll: continuous
  # Learned variance; keep consistent with backbone.learn_sigma below.
  model_var_type: learned
  clip_denoised: false
  legacy_posterior_log_variance: false
  backbone:
    # Dotted class path, presumably instantiated Hydra-style.
    _target_: src.modules.backbones.dit_backbone.DITBackbone
    # Shape settings; presumably must agree with the data section
    # (full_seq_len 30, num_agents 11, coord_dim 2).
    max_seq_len: 30
    num_agents: 11
    coord_dim: 2
    context_channels: 2
    context_dim: 256
    # Temporal blocks: layer count, width, heads, feed-forward width.
    n_temporal_layer: 4
    d_model_temporal: 256
    nhead_temporal: 8
    d_ff_temporal: 512
    # Spatial blocks: layer count, width, heads, feed-forward width.
    n_spatial_layer: 4
    d_model_spatial: 256
    nhead_spatial: 8
    d_ff_spatial: 512
    # Presumably must equal model.timesteps (1000).
    num_timesteps: 1000
    # 30 frames / 5 = 6 patches, assuming patching along time — confirm.
    patch_size: 5
    learn_sigma: true
# Global random seed.
seed: 42
# Device selection. NOTE(review): overlaps with the Lightning
# accelerator/devices settings in this file; confirm which one takes effect.
hardware:
  use_gpu: true
  gpu_devices: 1
# Weights & Biases run settings.
wandb:
  enabled: true
  project: trajectory-filling-ddpm
  # null: presumably falls back to the account's default entity.
  entity: null
  # Run name; the "lr=1e-4" suffix is not derived from the optim
  # learning_rate (0.0001), so keep the two in sync by hand.
  name: dit-learnsigma-mixedmask-lr=1e-4
  save_dir: ./wandb
  log_model: false
# Additional logger settings.
logging:
  # null: no backend selected here; presumably W&B (wandb.enabled) is used
  # instead — confirm.
  backend: null
  tensorboard:
    save_dir: ./tensorboard_logs
    name: null
# Training loop configuration.
# NOTE(review): nesting restored from an indentation-stripped copy. The DPM
# `sampling` mapping is placed under val_trajectory_metrics; left flat, it
# would collide with the later top-level `sampling` key, and the last
# duplicate would silently win. Confirm against the consuming code.
trainer:
  # Presumably forwarded to the PyTorch Lightning Trainer.
  lightning:
    max_epochs: 1500
    accelerator: auto
    devices: 1
    precision: 32
    log_every_n_steps: 50
    # Run validation every 20 training epochs.
    check_val_every_n_epoch: 20
    val_check_interval: null
    limit_val_batches: null
    deterministic: false
    fast_dev_run: false
  val_logging:
    enabled: true
    num_samples: 6
    log_every_n_val_epochs: 1
  optim:
    learning_rate: 0.0001
    betas: [0.9, 0.999]
    weight_decay: 0.0
  ema:
    enabled: true
    decay: 0.9999
    use_num_updates: true
  checkpoint:
    # Keep the 4 checkpoints with the lowest val/jade_min_30frames.
    monitor: val/jade_min_30frames
    mode: min
    save_top_k: 4
  val_trajectory_metrics:
    enabled: true
    max_samples: 512
    every_n_val_epochs: 1
    # Samples per input; presumably feeds min-over-samples metrics such as
    # the monitored jade_min — confirm.
    num_paths: 20
    horizon_stride: 30
    metrics_start_t: 0
    prediction_mode: pure
    # Same value as model.guidance_scale.
    guidance_scale_override: 2.0
    verbose: false
    # Presumably the sampler used for these metrics: 40-step, 3rd-order
    # DPM-Solver++ (singlestep).
    sampling:
      method: dpm
      dpm:
        steps: 40
        order: 3
        skip_type: time_uniform
        algorithm_type: dpmsolver++
        solver_method: singlestep
        lower_order_final: true
        denoise_to_zero: false
# Default sampler. Presumably used outside the validation-metric pass, which
# sets its own DPM-Solver sampler. NOTE(review): confirm this mapping is
# top-level rather than nested under trainer.
sampling:
  method: ancestral