---
# Dataset configuration.
data:
  name: trajectory_nba_filling
  full_seq_len: 30
  delta_len: 30
  num_agents: 11
  coord_dim: 2
  court_width: 94.0
  court_height: 50.0
  # Dataset/datamodule arguments. full_seq_len, num_agents, coord_dim and
  # court_* repeat the values above; presumably both copies must stay in
  # sync — verify against the loader.
  params:
    train_batch_size: 128
    val_batch_size: 512
    full_seq_len: 30
    num_agents: 11
    coord_dim: 2
    court_width: 94.0
    court_height: 50.0
    train_path: data/nba_train.npy
    val_path: data/nba_test.npy
    trajectory_key: trajectory
    context_key: context
    mask_key: obs_mask
    position_0_key: position_0
    # Two-element lists; presumably one entry per coordinate (coord_dim: 2).
    context_fill:
    - -4.0
    - -4.0
    include_delta_in_context: true
    delta_context_fill:
    - 0.0
    - 0.0
    delta_shift:
    - 0.0
    - 0.0
    delta_scale:
    - 1.0
    - 1.0
    masking:
      # Training masks come from a weighted mixture of mask specs
      # (weights 0.35 + 0.35 + 0.3 = 1.0).
      train:
        mixture:
        - weight: 0.35
          even:
            mode: random
            prefix_min: 1
            prefix_max: 20
        - weight: 0.35
          agent:
            mode: random
            n_masked_min: 5
            n_masked_max: 11
        # `hybrid` nests two sub-specs (`a`, `b`) joined by `combine: union`.
        - weight: 0.3
          hybrid:
            combine: union
            a:
              even:
                mode: random
                prefix_min: 1
                prefix_max: 20
            b:
              agent:
                mode: random
                n_masked_min: 5
                n_masked_max: 11
      # NOTE(review): absolute, machine-specific path, whereas train_path and
      # val_path are relative — consider data/nba_test_mask_v1.npy.
      val_mask_path: /root/code/gameplay-trajectory-diffusion/data/nba_test_mask_v1.npy
    num_workers: 0
    max_val_samples: 5120
# Diffusion model configuration.
model:
  name: trajectory_filling_ddpm
  trajectory_key: trajectory
  context_key: context
  mask_key: obs_mask
  position_0_key: position_0
  guidance_scale: 2.0
  p_uncond: 0.1
  log_blend_trajectory_video: false
  timesteps: 1000
  beta_schedule: linear
  # Keep fixed-point notation: PyYAML's YAML 1.1 resolver reads exponent
  # forms without a dot (e.g. 1e-4) as strings, not floats.
  linear_start: 0.0001
  linear_end: 0.02
  cosine_s: 0.008
  parameterization: eps
  loss_type: l2
  diffusion_loss_type: rescaled_mse
  log_diagnostic_vb: true
  vb_decoder_nll: continuous
  model_var_type: learned
  clip_denoised: false
  legacy_posterior_log_variance: false
  backbone:
    _target_: src.modules.backbones.dit_backbone.DITBackbone
    # max_seq_len / num_agents / coord_dim / num_timesteps repeat the data and
    # model settings; presumably they must agree — verify.
    max_seq_len: 30
    num_agents: 11
    coord_dim: 2
    context_channels: 2
    context_dim: 256
    n_temporal_layer: 4
    d_model_temporal: 256
    nhead_temporal: 8
    d_ff_temporal: 512
    n_spatial_layer: 4
    d_model_spatial: 256
    nhead_spatial: 8
    d_ff_spatial: 512
    num_timesteps: 1000
    patch_size: 5
    learn_sigma: true
seed: 42
# Hardware selection.
hardware:
  use_gpu: true
  gpu_devices: 1
# Weights & Biases run settings.
wandb:
  enabled: true
  project: trajectory-filling-ddpm
  entity: null
  name: dit-learnsigma-mixedmask-lr=1e-4
  save_dir: ./wandb
  log_model: false
# Additional logger selection; backend is unset (null) in this run.
logging:
  backend: null
  tensorboard:
    save_dir: ./tensorboard_logs
    name: null
# Training loop configuration.
trainer:
  # Keys match PyTorch Lightning `Trainer` arguments; presumably forwarded
  # as kwargs — verify against the training script.
  lightning:
    max_epochs: 1500
    accelerator: auto
    devices: 1
    precision: 32
    log_every_n_steps: 50
    check_val_every_n_epoch: 20
    val_check_interval: null
    limit_val_batches: null
    deterministic: false
    fast_dev_run: false
  val_logging:
    enabled: true
    num_samples: 6
    log_every_n_val_epochs: 1
  optim:
    # Keep fixed-point notation: PyYAML's YAML 1.1 resolver reads 1e-4
    # as a string, not a float.
    learning_rate: 0.0001
    betas:
    - 0.9
    - 0.999
    weight_decay: 0.0
  ema:
    enabled: true
    decay: 0.9999
    use_num_updates: true
  checkpoint:
    monitor: val/jade_min_30frames
    mode: min
    save_top_k: 4
  val_trajectory_metrics:
    enabled: true
    max_samples: 512
    every_n_val_epochs: 1
    num_paths: 20
    horizon_stride: 30
    metrics_start_t: 0
    prediction_mode: pure
    guidance_scale_override: 2.0
    verbose: false
    # Sampler settings scoped to val_trajectory_metrics; this is a separate
    # key from the sibling trainer.sampling (ancestral).
    sampling:
      method: dpm
      dpm:
        steps: 40
        order: 3
        skip_type: time_uniform
        algorithm_type: dpmsolver++
        solver_method: singlestep
        lower_order_final: true
        denoise_to_zero: false
  sampling:
    method: ancestral
|