wezteoh commited on
Commit
fac6f28
·
verified ·
1 Parent(s): c026f6b

upload models

Browse files
config.yaml ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ name: trajectory_nba_filling
3
+ full_seq_len: 30
4
+ delta_len: 30
5
+ num_agents: 11
6
+ coord_dim: 2
7
+ court_width: 94.0
8
+ court_height: 50.0
9
+ params:
10
+ train_batch_size: 128
11
+ val_batch_size: 512
12
+ full_seq_len: 30
13
+ num_agents: 11
14
+ coord_dim: 2
15
+ court_width: 94.0
16
+ court_height: 50.0
17
+ train_path: data/nba_train.npy
18
+ val_path: data/nba_test.npy
19
+ trajectory_key: trajectory
20
+ context_key: context
21
+ mask_key: obs_mask
22
+ position_0_key: position_0
23
+ context_fill:
24
+ - -4.0
25
+ - -4.0
26
+ include_delta_in_context: true
27
+ delta_context_fill:
28
+ - 0.0
29
+ - 0.0
30
+ delta_shift:
31
+ - 0.0
32
+ - 0.0
33
+ delta_scale:
34
+ - 1.0
35
+ - 1.0
36
+ masking:
37
+ train:
38
+ mixture:
39
+ - weight: 0.35
40
+ even:
41
+ mode: random
42
+ prefix_min: 1
43
+ prefix_max: 20
44
+ - weight: 0.35
45
+ agent:
46
+ mode: random
47
+ n_masked_min: 5
48
+ n_masked_max: 11
49
+ - weight: 0.3
50
+ hybrid:
51
+ combine: union
52
+ a:
53
+ even:
54
+ mode: random
55
+ prefix_min: 1
56
+ prefix_max: 20
57
+ b:
58
+ agent:
59
+ mode: random
60
+ n_masked_min: 5
61
+ n_masked_max: 11
62
+ val_mask_path: /root/code/gameplay-trajectory-diffusion/data/nba_test_mask_v1.npy
63
+ num_workers: 0
64
+ max_val_samples: 5120
65
+ model:
66
+ name: trajectory_filling_ddpm
67
+ trajectory_key: trajectory
68
+ context_key: context
69
+ mask_key: obs_mask
70
+ position_0_key: position_0
71
+ guidance_scale: 2.0
72
+ p_uncond: 0.1
73
+ log_blend_trajectory_video: false
74
+ timesteps: 1000
75
+ beta_schedule: linear
76
+ linear_start: 0.0001
77
+ linear_end: 0.02
78
+ cosine_s: 0.008
79
+ parameterization: eps
80
+ loss_type: l2
81
+ diffusion_loss_type: rescaled_mse
82
+ log_diagnostic_vb: true
83
+ vb_decoder_nll: continuous
84
+ model_var_type: learned
85
+ clip_denoised: false
86
+ legacy_posterior_log_variance: false
87
+ backbone:
88
+ _target_: src.modules.backbones.dit_backbone.DITBackbone
89
+ max_seq_len: 30
90
+ num_agents: 11
91
+ coord_dim: 2
92
+ context_channels: 2
93
+ context_dim: 256
94
+ n_temporal_layer: 4
95
+ d_model_temporal: 256
96
+ nhead_temporal: 8
97
+ d_ff_temporal: 512
98
+ n_spatial_layer: 4
99
+ d_model_spatial: 256
100
+ nhead_spatial: 8
101
+ d_ff_spatial: 512
102
+ num_timesteps: 1000
103
+ patch_size: 5
104
+ learn_sigma: true
105
+ seed: 42
106
+ hardware:
107
+ use_gpu: true
108
+ gpu_devices: 1
109
+ wandb:
110
+ enabled: true
111
+ project: trajectory-filling-ddpm
112
+ entity: null
113
+ name: dit-learnsigma-mixedmask-lr=1e-4
114
+ save_dir: ./wandb
115
+ log_model: false
116
+ logging:
117
+ backend: null
118
+ tensorboard:
119
+ save_dir: ./tensorboard_logs
120
+ name: null
121
+ trainer:
122
+ lightning:
123
+ max_epochs: 1500
124
+ accelerator: auto
125
+ devices: 1
126
+ precision: 32
127
+ log_every_n_steps: 50
128
+ check_val_every_n_epoch: 20
129
+ val_check_interval: null
130
+ limit_val_batches: null
131
+ deterministic: false
132
+ fast_dev_run: false
133
+ val_logging:
134
+ enabled: true
135
+ num_samples: 6
136
+ log_every_n_val_epochs: 1
137
+ optim:
138
+ learning_rate: 0.0001
139
+ betas:
140
+ - 0.9
141
+ - 0.999
142
+ weight_decay: 0.0
143
+ ema:
144
+ enabled: true
145
+ decay: 0.9999
146
+ use_num_updates: true
147
+ checkpoint:
148
+ monitor: val/jade_min_30frames
149
+ mode: min
150
+ save_top_k: 4
151
+ val_trajectory_metrics:
152
+ enabled: true
153
+ max_samples: 512
154
+ every_n_val_epochs: 1
155
+ num_paths: 20
156
+ horizon_stride: 30
157
+ metrics_start_t: 0
158
+ prediction_mode: pure
159
+ guidance_scale_override: 2.0
160
+ verbose: false
161
+ sampling:
162
+ method: dpm
163
+ dpm:
164
+ steps: 40
165
+ order: 3
166
+ skip_type: time_uniform
167
+ algorithm_type: dpmsolver++
168
+ solver_method: singlestep
169
+ lower_order_final: true
170
+ denoise_to_zero: false
171
+ sampling:
172
+ method: ancestral
traj-ddpm-1459-456980-val_jade_min_30frames=3.9700.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c97014510fe8d709f45969e75b1d88e1322ab4ed11f78faa4b6a14d8a765af96
3
+ size 122520563