File size: 3,757 Bytes
fac6f28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Dataset / dataloader configuration for NBA trajectory filling.
data:
  name: trajectory_nba_filling
  # NOTE(review): the sequence/court fields below are repeated verbatim under
  # `params`; keep both copies in sync when editing either one.
  full_seq_len: 30
  delta_len: 30
  num_agents: 11
  coord_dim: 2
  court_width: 94.0
  court_height: 50.0
  params:
    train_batch_size: 128
    val_batch_size: 512
    full_seq_len: 30
    num_agents: 11
    coord_dim: 2
    court_width: 94.0
    court_height: 50.0
    # All data files use repo-relative paths so the config is not tied to one
    # machine's checkout location (presumably resolved against the launch
    # directory -- confirm with the data module).
    train_path: data/nba_train.npy
    val_path: data/nba_test.npy
    # Key names; model.* repeats the same values.
    trajectory_key: trajectory
    context_key: context
    mask_key: obs_mask
    position_0_key: position_0
    # Two-element vectors, presumably one value per coordinate (coord_dim: 2).
    context_fill:
    - -4.0
    - -4.0
    include_delta_in_context: true
    delta_context_fill:
    - 0.0
    - 0.0
    # Shift 0 / scale 1 presumably leaves deltas un-normalized -- confirm.
    delta_shift:
    - 0.0
    - 0.0
    delta_scale:
    - 1.0
    - 1.0
    masking:
      # Training masks come from a weighted mixture (weights sum to 1.0).
      train:
        mixture:
        - weight: 0.35
          even:
            mode: random
            prefix_min: 1
            prefix_max: 20
        # n_masked_max equals num_agents (11).
        - weight: 0.35
          agent:
            mode: random
            n_masked_min: 5
            n_masked_max: 11
        # Union of the two mask kinds above, with identical parameters.
        - weight: 0.3
          hybrid:
            combine: union
            a:
              even:
                mode: random
                prefix_min: 1
                prefix_max: 20
            b:
              agent:
                mode: random
                n_masked_min: 5
                n_masked_max: 11
      # Presumably precomputed validation masks. Relative, like
      # train_path/val_path, instead of an absolute /root/... path.
      val_mask_path: data/nba_test_mask_v1.npy
    num_workers: 0
    max_val_samples: 5120
# Diffusion model configuration (DDPM with a DiT backbone).
model:
  name: 'trajectory_filling_ddpm'
  # Key names; identical to the corresponding data.params.* values.
  trajectory_key: 'trajectory'
  context_key: 'context'
  mask_key: 'obs_mask'
  position_0_key: 'position_0'
  # Classifier-free guidance settings.
  guidance_scale: 2.0
  p_uncond: 0.1
  log_blend_trajectory_video: false
  # Noise schedule. cosine_s is presumably only read when beta_schedule is
  # 'cosine'.
  timesteps: 1000
  beta_schedule: 'linear'
  linear_start: 0.0001
  linear_end: 0.02
  cosine_s: 0.008
  # Objective / variance settings; model_var_type 'learned' pairs with
  # backbone.learn_sigma: true below.
  parameterization: 'eps'
  loss_type: 'l2'
  diffusion_loss_type: 'rescaled_mse'
  log_diagnostic_vb: true
  vb_decoder_nll: 'continuous'
  model_var_type: 'learned'
  clip_denoised: false
  legacy_posterior_log_variance: false
  backbone:
    _target_: 'src.modules.backbones.dit_backbone.DITBackbone'
    # Shape fields mirror data.full_seq_len / num_agents / coord_dim.
    max_seq_len: 30
    num_agents: 11
    coord_dim: 2
    # NOTE(review): data.params.include_delta_in_context is true; confirm the
    # context tensor really has 2 channels and not coord + delta channels.
    context_channels: 2
    context_dim: 256
    # Temporal transformer stack.
    n_temporal_layer: 4
    d_model_temporal: 256
    nhead_temporal: 8
    d_ff_temporal: 512
    # Spatial (cross-agent) transformer stack.
    n_spatial_layer: 4
    d_model_spatial: 256
    nhead_spatial: 8
    d_ff_spatial: 512
    # Matches model.timesteps.
    num_timesteps: 1000
    # 5 divides max_seq_len (30) evenly.
    patch_size: 5
    learn_sigma: true
seed: 42  # global random seed
# Device selection; trainer.lightning.devices is also 1.
hardware: {use_gpu: true, gpu_devices: 1}
# Weights & Biases experiment tracking.
wandb:
  enabled: true
  project: 'trajectory-filling-ddpm'
  # null presumably falls back to the default W&B entity -- confirm.
  entity: null
  # Run name encodes trainer.optim.learning_rate (0.0001 == 1e-4).
  name: 'dit-learnsigma-mixedmask-lr=1e-4'
  save_dir: './wandb'
  log_model: false
# Logger backend selection.
# NOTE(review): backend is null while wandb.enabled is true; confirm which
# logger the trainer actually instantiates.
logging:
  backend: null
  tensorboard: {save_dir: './tensorboard_logs', name: null}
# Training loop configuration.
trainer:
  # Presumably forwarded as keyword arguments to the Lightning Trainer.
  lightning:
    max_epochs: 1500
    accelerator: 'auto'
    devices: 1
    precision: 32
    log_every_n_steps: 50
    # Validation runs once every 20 epochs.
    check_val_every_n_epoch: 20
    val_check_interval: null
    limit_val_batches: null
    deterministic: false
    fast_dev_run: false
  val_logging:
    enabled: true
    num_samples: 6
    log_every_n_val_epochs: 1
  optim:
    learning_rate: 0.0001
    betas: [0.9, 0.999]
    weight_decay: 0.0
  ema:
    enabled: true
    decay: 0.9999
    use_num_updates: true
  # Presumably ModelCheckpoint-style: keep the 4 lowest-scoring checkpoints.
  checkpoint:
    monitor: 'val/jade_min_30frames'
    mode: 'min'
    save_top_k: 4
  val_trajectory_metrics:
    enabled: true
    max_samples: 512
    every_n_val_epochs: 1
    num_paths: 20
    # Equals data.full_seq_len (30).
    horizon_stride: 30
    metrics_start_t: 0
    prediction_mode: 'pure'
    # Same value as model.guidance_scale.
    guidance_scale_override: 2.0
    verbose: false
    sampling:
      method: 'dpm'
      dpm:
        steps: 40
        order: 3
        skip_type: 'time_uniform'
        algorithm_type: 'dpmsolver++'
        solver_method: 'singlestep'
        lower_order_final: true
        denoise_to_zero: false
# Top-level sampler choice; trainer.val_trajectory_metrics.sampling configures
# its own DPM-Solver sampler separately.
sampling:
  method: 'ancestral'