pengzhenghao commited on
Commit
89f8755
·
1 Parent(s): b8a7066

Set up self-contained Gradio Space

Browse files

Bundle the app code and tiny demo dataset so the Hugging Face Space can boot directly into the SceneStreamer demo with sensible headless defaults.

Made-with: Cursor

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +12 -0
  2. README.md +28 -4
  3. app.py +58 -0
  4. cfgs/motion_default.yaml +211 -0
  5. cfgs/scenestreamer-base-large.yaml +96 -0
  6. cfgs/scenestreamer-base-small.yaml +96 -0
  7. cfgs/scenestreamer-base-xl.yaml +97 -0
  8. cfgs/scenestreamer-full-large-nors.yaml +99 -0
  9. cfgs/scenestreamer-full-large.yaml +99 -0
  10. cfgs/scenestreamer-full-small.yaml +96 -0
  11. cfgs/scenestreamer-full-xl.yaml +100 -0
  12. data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl +3 -0
  13. data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl +3 -0
  14. data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl +3 -0
  15. data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl +3 -0
  16. data/20scenarios/process.ipynb +128 -0
  17. data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl +3 -0
  18. data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl +3 -0
  19. data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl +3 -0
  20. data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl +3 -0
  21. data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl +3 -0
  22. data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl +3 -0
  23. data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl +3 -0
  24. data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl +3 -0
  25. data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl +3 -0
  26. data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl +3 -0
  27. data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl +3 -0
  28. data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl +3 -0
  29. data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl +3 -0
  30. data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl +3 -0
  31. data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl +3 -0
  32. data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl +3 -0
  33. data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl +3 -0
  34. data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl +3 -0
  35. data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl +3 -0
  36. data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl +3 -0
  37. packages.txt +7 -0
  38. pyproject.toml +80 -0
  39. requirements.txt +1 -0
  40. scenestreamer/__init__.py +0 -0
  41. scenestreamer/cli.py +293 -0
  42. scenestreamer/clustering.sh +7 -0
  43. scenestreamer/dataset/__init__.py +0 -0
  44. scenestreamer/dataset/constants.py +44 -0
  45. scenestreamer/dataset/datamodule.py +49 -0
  46. scenestreamer/dataset/dataset.py +630 -0
  47. scenestreamer/dataset/make_lmdb.py +233 -0
  48. scenestreamer/dataset/preprocess_action_label.py +293 -0
  49. scenestreamer/dataset/preprocessor.py +0 -0
  50. scenestreamer/dataset/scenarionet_utils.py +239 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ .DS_Store
6
+ artifacts/
7
+ outputs/
8
+ lightning_logs/
9
+ wandb/
10
+ scenestreamer/outputs/
11
+ scenestreamer/lightning_logs/
12
+ scenestreamer/eval/outputs/
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: SceneStreamer
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 6.9.0
8
  app_file: app.py
@@ -10,4 +10,28 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: SceneStreamer
3
+ emoji: 🚗
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 6.9.0
8
  app_file: app.py
 
10
  license: mit
11
  ---
12
 
13
+ # SceneStreamer Space
14
+
15
+ This Space hosts the interactive Gradio demo for `SceneStreamer`.
16
+
17
+ What is included here:
18
+
19
+ - the Gradio app entrypoint in `app.py`
20
+ - the SceneStreamer package code needed by the demo
21
+ - a tiny bundled ScenarioNet subset in `data/20scenarios`
22
+
23
+ Default behavior:
24
+
25
+ - the app loads the bundled demo subset automatically
26
+ - the model checkpoint is fetched from the Hugging Face Hub by default
27
+ - `SCENESTREAMER_DEVICE` defaults to `cpu`
28
+
29
+ Optional Space variables:
30
+
31
+ - `SCENESTREAMER_DATASET_DIR`
32
+ - `SCENESTREAMER_HF_REPO`
33
+ - `SCENESTREAMER_HF_FILE`
34
+ - `SCENESTREAMER_CKPT`
35
+ - `SCENESTREAMER_DEVICE`
36
+
37
+ If the app shows the setup screen instead of the demo, the dataset path is missing or the demo subset was not uploaded with the repo.
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ os.environ.setdefault("MPLBACKEND", "Agg")
7
+ os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
8
+ os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
9
+ os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
10
+
11
+ import gradio as gr
12
+
13
+ from scenestreamer.gradio_ui.demo_app import DEFAULT_HF_FILE, DEFAULT_HF_REPO, build_demo
14
+
15
+
16
+ def _build_space_demo() -> gr.Blocks:
17
+ dataset_dir = os.environ.get("SCENESTREAMER_DATASET_DIR", "data/20scenarios")
18
+ hf_repo = os.environ.get("SCENESTREAMER_HF_REPO", DEFAULT_HF_REPO)
19
+ hf_file = os.environ.get("SCENESTREAMER_HF_FILE", DEFAULT_HF_FILE)
20
+ ckpt = os.environ.get("SCENESTREAMER_CKPT") or None
21
+ device = os.environ.get("SCENESTREAMER_DEVICE", "cpu")
22
+
23
+ if not Path(dataset_dir).exists():
24
+ with gr.Blocks(title="SceneStreamer Space Setup") as demo:
25
+ gr.Markdown("## SceneStreamer Space Setup Required")
26
+ gr.Markdown(
27
+ "This Space needs a local ScenarioNet dataset directory before the interactive demo can start.\n\n"
28
+ f"Current `SCENESTREAMER_DATASET_DIR`: `{dataset_dir}`"
29
+ )
30
+ gr.Markdown(
31
+ "Set Space variables or attach storage, then restart the Space:\n"
32
+ "- `SCENESTREAMER_DATASET_DIR`\n"
33
+ "- `SCENESTREAMER_HF_REPO` (optional)\n"
34
+ "- `SCENESTREAMER_HF_FILE` (optional)\n"
35
+ "- `SCENESTREAMER_CKPT` (optional local checkpoint)\n"
36
+ "- `SCENESTREAMER_DEVICE` (default `cpu`)"
37
+ )
38
+ gr.Markdown(
39
+ "This repo is expected to include a tiny bundled demo subset under `data/20scenarios`. "
40
+ "If you are seeing this page after pushing the repo, the demo data was likely not uploaded."
41
+ )
42
+ return demo
43
+
44
+ return build_demo(
45
+ dataset_dir=dataset_dir,
46
+ hf_repo=hf_repo,
47
+ hf_file=hf_file,
48
+ ckpt=ckpt,
49
+ device=device,
50
+ )
51
+
52
+
53
+ demo = _build_space_demo()
54
+
55
+
56
+ if __name__ == "__main__":
57
+ demo.launch()
58
+
cfgs/motion_default.yaml ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+
4
+ # Experiment related
5
+ exp_name: 'default'
6
+ seed: 0
7
+ epochs: 50
8
+ batch_size: 10
9
+ val_batch_size: 4
10
+ num_workers: 16
11
+ val_num_workers: 16
12
+ num_sanity_val_steps: 100
13
+ val_check_interval: 1.0
14
+ wandb: False
15
+ log_dir: Null
16
+ limit_train_batches: -1
17
+ limit_val_batches: -1
18
+ prefetch_factor: 2
19
+ ckpt: Null
20
+ eval: False
21
+ pretrain: Null
22
+ deterministic: False
23
+ detect_anomaly: False
24
+ check_val_every_n_epoch: 1
25
+
26
+ USE_RL_FINETUNING: False
27
+
28
+
29
+ # Turn both on when training TG to match TrafficGen's metrics.
30
+ LIMIT_MAP_RANGE: False
31
+ FOLLOW_TRAFFICGEN: False
32
+ FORCE_SDC_FOR_TRAFFICGEN: False
33
+ ONLY_LANE_FOR_TRAFFICGEN: False
34
+
35
+ # True then agent info will not be used in encoder,
36
+ # and new tokens for history will be added for decoder.
37
+ GPT_STYLE: false
38
+ REMOVE_AGENT_FROM_SCENE_ENCODER: false
39
+ USE_DIFFUSION: false
40
+ USE_ADALN: null
41
+ BACKWARD_PREDICTION: false
42
+
43
+
44
+ USE_DESTINATION: false
45
+
46
+ TF_DEST: True
47
+
48
+
49
+ ADD_CONTOUR_RELATION: false
50
+
51
+ DELTA_TOKENIZER_FILE_NAME: "1030_argsort_less_256_128_128.pkl"
52
+
53
+
54
+ USE_TRAFFICGEN: false
55
+ TRAIN_TRAFFICGEN: null
56
+ USE_MOTION: true
57
+
58
+ EVAL_MOTION: true
59
+ EVAL_TRAFFICGEN: false
60
+
61
+ DELTA_POS_IS_VELOCITY: false
62
+
63
+ SIMPLE_RELATION: false
64
+ SIMPLE_RELATION_FACTOR: 1
65
+
66
+ RECONSTRUCT_MAP: false
67
+
68
+ UPDATE_RELATION: false
69
+
70
+ REMOVE_REL_NORM: false
71
+
72
+ DATA:
73
+ TRAINING_DATA_DIR: '/data/datasets/scenarionet/waymo/training'
74
+ TEST_DATA_DIR: '/data/datasets/scenarionet/waymo/validation'
75
+ ADV_INFO_PATH: 'data/all_adv.pkl'
76
+ SAMPLE_INTERVAL_TRAINING: 1
77
+ SAMPLE_INTERVAL_TEST: 1
78
+ SD_PASSTHROUGH: false
79
+ ALLOW_CACHE: false
80
+ RETURN_HALFWAY: false # Only used when generating LMDB dataset
81
+ USE_LMDB: false
82
+ USE_CACHE: false
83
+
84
+ PREPROCESSING:
85
+ MAX_VECTORS: 128
86
+ MAX_MAP_FEATURES: 512
87
+ MAX_LENGTH_PER_MAP_FEATURE: 10000 # Useless
88
+ MAX_AGENTS: 128
89
+ MAX_TRAFFIC_LIGHTS: 64
90
+ PADDING_TO_MAX: false
91
+ keep_all_data: false # for debug
92
+ ADD_SDC_TO_OBJECT_OF_INTEREST: true # Should be True when WOSAC
93
+ REMOVE_TRAFFIC_LIGHT_STATE: false
94
+ TRUNCATE_TIME: -1
95
+
96
+ TRAINING:
97
+ PREDICT_ALL_AGENTS: true
98
+
99
+ EVALUATION:
100
+ NAME: 'waymo_motion_prediction'
101
+ PREDICT_ALL_AGENTS: false
102
+ DELETE_EVAL_RESULT: true
103
+ NUM_MODES: 6
104
+ MAXIMUM_BATCH_SIZE: 10000
105
+ USE_CACHE: true
106
+ USE_TG_AS_GT: 1111
107
+ TG_REJECT_SAMPLING: True
108
+ TG_SDC_DISTANCE_MASKING: False
109
+
110
+ MODEL:
111
+ NAME: 'motionlm'
112
+ D_MODEL: 256
113
+ NUM_ATTN_LAYERS: 4
114
+ NUM_ATTN_HEAD: 8
115
+ # DROPOUT_OF_ATTN: 0.0
116
+ DROPOUT: 0.0
117
+ NUM_DECODER_LAYERS: 6
118
+ ADD_PE_FOR_TOKEN: true
119
+ RELATIVE_PE: true
120
+ RELATIVE_PE_DECODER: false
121
+ PRE_PROJECTION: false
122
+ KNN: 128
123
+ S2S_DISTANCE: null
124
+ SELF_ATTN_KNN: 128
125
+ CROSS_ATTN_KNN: 128
126
+ RANDOMIZE_AGENT_ID: true
127
+ A2S_KNN: null
128
+ A2S_DISTANCE: null
129
+ A2A_KNN: null
130
+ A2A_DISTANCE: null
131
+ ADD_RELATION_TO_V: false
132
+ IS_V7: False
133
+ PER_CONTOUR_POINT_RELATION: null
134
+
135
+ TOKENIZATION:
136
+ TOKENIZATION_METHOD: delta_delta
137
+ NUM_SKIPPED_STEPS: 5
138
+ NUM_BINS: 13
139
+ X_MAX: 3.5 # <<< Deprecated
140
+ X_MIN: -3.5 # <<< Deprecated
141
+ Y_MAX: 3.5 # <<< Deprecated
142
+ Y_MIN: -3.5 # <<< Deprecated
143
+ ADD_NOISE: false
144
+ NOISE_TOPK: 5
145
+ ALLOW_SKIP_STEP: True
146
+
147
+ MIN_DISPLACEMENT: 0.1
148
+ MIN_DISPLACEMENT_INIT: null
149
+ MIN_SPEED: null
150
+ SMOOTH_FACTOR: null
151
+ MAX_HEADING_DIFF: null
152
+
153
+ USE_CONTOUR_ERROR: True
154
+
155
+ VEH_LIMIT: 3.5
156
+ PED_LIMIT: 3.5
157
+ CYC_LIMIT: 3.5
158
+
159
+ FLIP_WRONG_HEADING: false
160
+ SHOULD_STANDARDIZE: true
161
+
162
+ # MIN_DISPLACEMENT: 0.3
163
+ # MIN_DISPLACEMENT_INIT: 1.0
164
+ # MIN_SPEED: 0.5
165
+ # SMOOTH_FACTOR: null
166
+ # MAX_HEADING_DIFF: 0.3
167
+
168
+
169
+ SAMPLING:
170
+ SAMPLING_METHOD: 'topp'
171
+ TEMPERATURE: 1.0
172
+ TOPP: 0.95
173
+
174
+
175
+ OPTIMIZATION:
176
+ # NUM_EPOCHS: 50
177
+ OPTIMIZER: AdamW
178
+ LR: 0.0003
179
+ WEIGHT_DECAY: 0.0
180
+ GRAD_NORM_CLIP: 1.0
181
+ SCHEDULER: cosine
182
+ WARMUP_STEPS: 2000
183
+ # TRAINING_STEPS: 300000
184
+ USE_FOCAL_LOSS: false
185
+
186
+
187
+ SUBMISSION:
188
+ GENERATE_SUBMISSION: false
189
+ PREFIX: "peng"
190
+ ACCOUNT: "dr.zhenghao.peng@gmail.com"
191
+ METHOD_NAME: "peng"
192
+ num_model_parameters: '10m' # TODO: Need to be changed accordingly!
193
+ SAVE_EVAL_DATA: true
194
+
195
+ TMP_DIR: "tmp" # Relative to repo root
196
+
197
+ ACTION_LABEL:
198
+ USE_ACTION_LABEL: false # Only valid for turning + acceleration
199
+ USE_SAFETY_LABEL: false
200
+ MASK_PROBABILITY_ACTION_LABEL: 0.0 # Might turn it on
201
+ MASK_PROBABILITY_SAFETY_LABEL: 0.0 # Might turn it on
202
+
203
+ LANGUAGE_CONDITION: false
204
+ FINE_TUNE_BERT: false
205
+
206
+ MCTS:
207
+ USE_MCTS: False
208
+ MCTS_DEPTH: -1
209
+ MCTS_WIDTH: -1
210
+
211
+ TOKEN_BUFFER_CACHE_LENGTH: 100
cfgs/scenestreamer-base-large.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-base-large'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250506_scenestreamer_v18_notg_large_FIXETYPE_2025-05-06/checkpoints"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: true
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-base-large"
96
+ num_model_parameters: '3.3m'
cfgs/scenestreamer-base-small.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-base-small'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer_v17_notg_finetune_FIXTYPE_2025-05-07/checkpoints"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: true
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 64
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-base-small"
96
+ num_model_parameters: '1.1m'
cfgs/scenestreamer-base-xl.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-base-xl'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250512_scenestreamer-base-xl_2025-05-12/checkpoints"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: true
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-base-xl"
96
+ num_model_parameters: '4.2m'
97
+ ACCOUNT: "dr.zhenghao.peng@gmail.com"
cfgs/scenestreamer-full-large-nors.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-full-large-nors'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: false
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-full-large"
96
+ num_model_parameters: '4.6m'
97
+
98
+ EVALUATION:
99
+ TG_REJECT_SAMPLING: False
cfgs/scenestreamer-full-large.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-full-large'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: false
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-full-large"
96
+ num_model_parameters: '4.6m'
97
+
98
+ EVALUATION:
99
+ TG_REJECT_SAMPLING: True
cfgs/scenestreamer-full-small.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-full-small'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250505_scenestreamer_v19_withtg_nodest_FIXEDAS_2025-05-05/checkpoints"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: false
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 64 # TODO: Need to increase? was 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-full-small"
96
+ num_model_parameters: '1.5m'
cfgs/scenestreamer-full-xl.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - motion_default
3
+ - _self_
4
+
5
+ exp_name: 'scenestreamer-full-xl'
6
+ pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250518_scenestreamer-full-xs_2025-05-18/checkpoints/20250518_scenestreamer-full-xs_2025-05-18_epoch=1-step=62494.ckpt"
7
+
8
+ num_workers: 8
9
+ val_num_workers: 8
10
+ num_sanity_val_steps: 10
11
+
12
+ batch_size: 4
13
+ val_batch_size: 4
14
+ limit_val_batches: -1
15
+
16
+ eval_backward_model: False
17
+
18
+ epochs: 30
19
+ wandb: True
20
+ log_dir: /bigdata/zhenghao/scenestreamer
21
+
22
+ SCENESTREAMER_ATTENTION_KNN: 128
23
+ SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
24
+ SCENESTREAMER_NO_TG: false
25
+
26
+ REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
27
+
28
+ BACKWARD_PREDICTION: False # <<<
29
+ ADD_CONTOUR_RELATION: True # <<<
30
+
31
+ DELTA_POS_IS_VELOCITY: True
32
+ SIMPLE_RELATION: True
33
+
34
+ RECONSTRUCT_MAP: False
35
+ UPDATE_RELATION: False
36
+ REMOVE_REL_NORM: False # <<<
37
+
38
+ USE_TRAFFICGEN: True
39
+ USE_MOTION: True
40
+ EVAL_MOTION: True
41
+ EVAL_TRAFFICGEN: False
42
+
43
+ GPT_STYLE: True # <<<
44
+ USE_ADALN: False
45
+
46
+ SAMPLING:
47
+ TOPP: 0.95
48
+ TEMPERATURE: 1.0
49
+
50
+ TOKENIZATION:
51
+ TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
52
+ USE_CONTOUR_ERROR: True # <<<
53
+ ALLOW_SKIP_STEP: True
54
+ ADD_NOISE: False
55
+ NUM_BINS: 33
56
+
57
+ PREPROCESSING:
58
+ REMOVE_TRAFFIC_LIGHT_STATE: False
59
+ MAX_LENGTH_PER_MAP_FEATURE: 10
60
+ MAX_MAP_FEATURES: 3000
61
+ MAX_VECTORS: 30
62
+ MAX_AGENTS: 128
63
+ DEST_DROPOUT: 0.0
64
+ ADD_SDC_TO_OBJECT_OF_INTEREST: False
65
+
66
+ DATA:
67
+ TRAINING_DATA_DIR: ''
68
+ TEST_DATA_DIR: ''
69
+
70
+ MODEL:
71
+ USE_MOTION_HEAD_PRENORM: True
72
+ ALL_TO_MAP_3D: False
73
+ D_MODEL: 128
74
+ NAME: 'scenestreamer'
75
+ NUM_ATTN_HEAD: 8
76
+ # Encoder:
77
+ NUM_ATTN_LAYERS: 3
78
+ RELATIVE_PE: true
79
+ # Decoder:
80
+ NUM_DECODER_LAYERS: 6
81
+ RELATIVE_PE_DECODER: True
82
+ SIMPLE_RELATION_FACTOR: 1
83
+ # New:
84
+ KNN: -100
85
+ S2S_DISTANCE: -100
86
+ A2S_KNN: -100
87
+ A2S_DISTANCE: -100
88
+ A2A_KNN: -100
89
+ A2A_DISTANCE: -100
90
+ ADD_RELATION_TO_V: False
91
+ PER_CONTOUR_POINT_RELATION: False
92
+ IS_V7: True
93
+
94
+ SUBMISSION:
95
+ METHOD_NAME: "scenestreamer-full-xl"
96
+ num_model_parameters: '5.5m'
97
+ ACCOUNT: "dr.zhenghao.peng@gmail.com"
98
+
99
+ EVALUATION:
100
+ TG_REJECT_SAMPLING: True
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2f5a926ba159d4e9acec464c9d091c093d138a21796ce5c264fea7f4398a777
3
+ size 3007314
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a001efb1d464a0f9a3c06f26cf921f2308e81b26a65741491875201ede70b1
3
+ size 6364095
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:742aa8c350793d83e949396dcb055a17bcbdd4b2b728b41d5f1b5c6c5a897ce1
3
+ size 4382994
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:716fa8212d8d4dbed60703cf5a9a952129fab5ed9342da748a7a48f00478e6d9
3
+ size 11279523
data/20scenarios/process.ipynb ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "f003e6e4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os, pickle"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 4,
16
+ "id": "fdb6f299",
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "\u001b[0m\u001b[01;32mdataset_summary.pkl\u001b[0m*\r\n",
24
+ "\u001b[01;32mprocess.ipynb\u001b[0m*\r\n",
25
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl\u001b[0m*\r\n",
26
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl\u001b[0m*\r\n",
27
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl\u001b[0m*\r\n",
28
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl\u001b[0m*\r\n",
29
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl\u001b[0m*\r\n",
30
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl\u001b[0m*\r\n",
31
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl\u001b[0m*\r\n",
32
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl\u001b[0m*\r\n",
33
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl\u001b[0m*\r\n",
34
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl\u001b[0m*\r\n",
35
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl\u001b[0m*\r\n",
36
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl\u001b[0m*\r\n",
37
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl\u001b[0m*\r\n",
38
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_18840a098288507f.pkl\u001b[0m*\r\n",
39
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl\u001b[0m*\r\n",
40
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl\u001b[0m*\r\n",
41
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl\u001b[0m*\r\n",
42
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl\u001b[0m*\r\n",
43
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl\u001b[0m*\r\n",
44
+ "\u001b[01;32msd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl\u001b[0m*\r\n"
45
+ ]
46
+ }
47
+ ],
48
+ "source": [
49
+ "ls"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 3,
55
+ "id": "f08caf3b",
56
+ "metadata": {},
57
+ "outputs": [
58
+ {
59
+ "data": {
60
+ "text/plain": [
61
+ "['sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl',\n",
62
+ " 'sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl',\n",
63
+ " 'sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl',\n",
64
+ " 'sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl',\n",
65
+ " 'sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl',\n",
66
+ " 'sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl',\n",
67
+ " 'sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl',\n",
68
+ " 'sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl',\n",
69
+ " 'sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl',\n",
70
+ " 'sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl',\n",
71
+ " 'sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl',\n",
72
+ " 'sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl',\n",
73
+ " 'sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl',\n",
74
+ " 'sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl',\n",
75
+ " 'sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl',\n",
76
+ " 'sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl',\n",
77
+ " 'sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl',\n",
78
+ " 'sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl',\n",
79
+ " 'sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl',\n",
80
+ " 'sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl']"
81
+ ]
82
+ },
83
+ "execution_count": 3,
84
+ "metadata": {},
85
+ "output_type": "execute_result"
86
+ }
87
+ ],
88
+ "source": [
89
+ "[p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 6,
95
+ "id": "d1d50b21",
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "d = {}\n",
100
+ "for p in [p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]:\n",
101
+ " d[p] = {}\n",
102
+ "\n",
103
+ "pickle.dump(d, open(\"dataset_summary.pkl\", \"wb\"))"
104
+ ]
105
+ }
106
+ ],
107
+ "metadata": {
108
+ "kernelspec": {
109
+ "display_name": "Python 3 (ipykernel)",
110
+ "language": "python",
111
+ "name": "python3"
112
+ },
113
+ "language_info": {
114
+ "codemirror_mode": {
115
+ "name": "ipython",
116
+ "version": 3
117
+ },
118
+ "file_extension": ".py",
119
+ "mimetype": "text/x-python",
120
+ "name": "python",
121
+ "nbconvert_exporter": "python",
122
+ "pygments_lexer": "ipython3",
123
+ "version": "3.9.16"
124
+ }
125
+ },
126
+ "nbformat": 4,
127
+ "nbformat_minor": 5
128
+ }
data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de3ceb564e944fda1d5b8e2f72428ada2b790d5d57195e6d9bca2cdf761a37f0
3
+ size 338026
data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72cae539b2993d4dd668e5029460e2cdbf621dc1734d69ce30b357a156bc8375
3
+ size 535617
data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:383b8cdeaefc13c15751f3ff29d28426543c24b6de674e0c5434e39f5d7a8d1f
3
+ size 1060327
data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77209cb776b340b849ff123c7c9019d17495df41bf69fe2cfc765d5c8235fc66
3
+ size 680549
data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf6fc9426a5b209d4f0862ff8046844f40fbf822fcf54a24947f23aa0f140161
3
+ size 426242
data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddfd0775eedec97b30316beb8c537e3d768928abbd255030403c5bdff59a7f85
3
+ size 837708
data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b2f40bfe73b50954987f15a462a4a86659615e416b577b1c1096c8c69ef0d05
3
+ size 667369
data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30f1f5e1d3a5863a88b919a79c8c4a75bff2774df99412a092706d5cebc7ffbb
3
+ size 319336
data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bacb875cdcd59d4c5ada7c1a2b546a31967de6f15f288d81136d2f3cc12f2413
3
+ size 560466
data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5194b6697b1fa3d00113c8ba2f39d50ffb8066b004c8f2f5d15953a550027e7c
3
+ size 456308
data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfce54f80da166a8cbaf5fe167770975dd4dfcc9bba3c29ad228ad112d995690
3
+ size 396202
data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a15fb32533b7472c38ec32c676a9425490de2c95d8dd39b9aa2a6b98f1050512
3
+ size 253625
data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d126e0c7c1e4afa561739437ad7c7810830c766a45b725fb8cb8a33015029e61
3
+ size 490897
data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffbd89555db615f2d20516fedf2dd3689b4487221a82c9f48251e29c30027f31
3
+ size 627272
data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e93b1b9cbfd8c541c5c52523aee09f5abd3410d217bf6e402ec217935f5df08
3
+ size 435189
data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11ac7db4a7c917c8045c957ee93d2f761de089152fdb9b03554a889d4f1deca3
3
+ size 376448
data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cb753a3c0d7255b5d4333a95c5555b4b159dcc600f9da286787ab37893c70d8
3
+ size 1558365
data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ac83f07304e39a8c5e468399b12e3ed034d6855b2921ab094ee425e492bea6d
3
+ size 1319602
data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c81210cec08ec7bc5813e59bf4a2f8287800c2d626929441a0e7398f7283c509
3
+ size 1316507
data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26feb91313bee0c89ff1a9660fc352ea687daf04a7ca642f600d2ece17a6e301
3
+ size 1079778
packages.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ffmpeg
2
+ libgl1
3
+ libglib2.0-0
4
+ libsm6
5
+ libxext6
6
+ libxrender1
7
+ libsdl2-2.0-0
pyproject.toml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "scenestreamer"
3
+ version = "1.0.0"
4
+ description = "SceneStreamer: Continuous Scenario Generation as Next Token Group Prediction"
5
+ readme = "README.md"
6
+ license = {text = "MIT"}
7
+ authors = [
8
+ {name = "Zhenghao Peng", email = "pzh@berkeley.edu"}
9
+ ]
10
+ requires-python = ">=3.10,<3.12"
11
+ classifiers = [
12
+ "Development Status :: 4 - Beta",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Programming Language :: Python :: 3.10",
15
+ "Programming Language :: Python :: 3.11",
16
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
17
+ ]
18
+
19
+ dependencies = [
20
+ "torch>=2.0.0",
21
+ "torchvision",
22
+ "lightning>=2.0.0",
23
+ "hydra-core",
24
+ "omegaconf",
25
+ "numpy",
26
+ "tqdm",
27
+ "matplotlib",
28
+ "seaborn",
29
+ "Pillow",
30
+ "easydict",
31
+ "wandb",
32
+ "torch_geometric",
33
+ "transformers",
34
+ "tokenizers",
35
+ "huggingface_hub",
36
+ "tensorboardX",
37
+ "pyyaml",
38
+ "scikit-image",
39
+ "chardet",
40
+ "charset-normalizer",
41
+ "tabulate",
42
+ "metadrive-simulator",
43
+ "gradio>=6.9.0",
44
+ "scenarionet @ git+https://github.com/metadriverse/scenarionet.git",
45
+ ]
46
+
47
+ [project.scripts]
48
+ scenestreamer = "scenestreamer.cli:main"
49
+
50
+ [project.optional-dependencies]
51
+ dev = [
52
+ "ruff",
53
+ "pytest",
54
+ ]
55
+ rl = [
56
+ "stable-baselines3>=2.0.0",
57
+ "gymnasium>=0.29.0",
58
+ "ipython",
59
+ ]
60
+ # Note: waymo-open-dataset requires Python 3.10 and specific numpy versions.
61
+ # Install separately: pip install waymo-open-dataset-tf-2-12-0==1.6.4
62
+
63
+ [project.urls]
64
+ Homepage = "https://vail-ucla.github.io/scenestreamer/"
65
+ Repository = "https://github.com/pengzhenghao/scenestreamer"
66
+
67
+ [build-system]
68
+ requires = ["setuptools>=61.0"]
69
+ build-backend = "setuptools.build_meta"
70
+
71
+ [tool.setuptools.packages.find]
72
+ include = ["scenestreamer*"]
73
+
74
+ [tool.ruff]
75
+ line-length = 120
76
+ target-version = "py310"
77
+
78
+ [tool.ruff.lint]
79
+ select = ["E", "F", "W"]
80
+ ignore = ["E501"]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ .
scenestreamer/__init__.py ADDED
File without changes
scenestreamer/cli.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import pathlib
7
+ import runpy
8
+ import sys
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ import yaml
14
+
15
+
16
+ def _to_plain(obj: Any) -> Any:
17
+ if hasattr(obj, "items"):
18
+ return {k: _to_plain(v) for k, v in obj.items()}
19
+ if isinstance(obj, (list, tuple)):
20
+ return [_to_plain(v) for v in obj]
21
+ return obj
22
+
23
+
24
+ def _to_easydict(obj: Any):
25
+ from easydict import EasyDict
26
+
27
+ if isinstance(obj, dict):
28
+ return EasyDict({k: _to_easydict(v) for k, v in obj.items()})
29
+ if isinstance(obj, list):
30
+ return [_to_easydict(v) for v in obj]
31
+ return obj
32
+
33
+
34
+ def load_yaml_config(path: str | os.PathLike[str]):
35
+ with open(path, "r") as f:
36
+ data = yaml.safe_load(f)
37
+ return _to_easydict(data)
38
+
39
+
40
+ def apply_overrides(cfg, overrides: list[str]) -> None:
41
+ """
42
+ Apply overrides of form KEY=VALUE where KEY is dot-delimited.
43
+ VALUE is parsed using yaml.safe_load (so numbers/bools/lists work).
44
+ """
45
+ for item in overrides:
46
+ if "=" not in item:
47
+ raise ValueError(f"Invalid override (expected KEY=VALUE): {item}")
48
+ key, raw_val = item.split("=", 1)
49
+ value = yaml.safe_load(raw_val)
50
+
51
+ cur = cfg
52
+ parts = key.split(".")
53
+ for p in parts[:-1]:
54
+ if not hasattr(cur, p):
55
+ setattr(cur, p, _to_easydict({}))
56
+ cur = getattr(cur, p)
57
+ setattr(cur, parts[-1], value)
58
+
59
+
60
+ @dataclass(frozen=True)
61
+ class RunPaths:
62
+ run_dir: pathlib.Path
63
+ config_path: pathlib.Path
64
+ metrics_path: pathlib.Path
65
+
66
+
67
+ def make_run_dir(base_dir: str | os.PathLike[str], run_id: str | None) -> RunPaths:
68
+ base = pathlib.Path(base_dir)
69
+ base.mkdir(parents=True, exist_ok=True)
70
+ if run_id is None:
71
+ run_id = time.strftime("%Y%m%d-%H%M%S")
72
+ run_dir = base / run_id
73
+ run_dir.mkdir(parents=True, exist_ok=False)
74
+ return RunPaths(
75
+ run_dir=run_dir,
76
+ config_path=run_dir / "config.yaml",
77
+ metrics_path=run_dir / "metrics.json",
78
+ )
79
+
80
+
81
+ def cmd_preprocess(args: argparse.Namespace) -> None:
82
+ # Prefer failing fast on missing ScenarioNet without importing heavy deps.
83
+ try:
84
+ import scenarionet # noqa: F401
85
+ except ModuleNotFoundError as e:
86
+ raise e
87
+
88
+ from scenestreamer.dataset.dataset import SceneStreamerDataset
89
+
90
+ cfg = load_yaml_config(args.config)
91
+ apply_overrides(cfg, args.set or [])
92
+
93
+ # Paths: prefer CLI args, but allow config overrides.
94
+ if args.train_dir:
95
+ cfg.DATA.TRAINING_DATA_DIR = args.train_dir
96
+ if args.test_dir:
97
+ cfg.DATA.TEST_DATA_DIR = args.test_dir
98
+
99
+ cfg.DATA.USE_CACHE = True
100
+
101
+ run = make_run_dir(args.artifacts_dir, args.run_id)
102
+ with open(run.config_path, "w") as f:
103
+ yaml.safe_dump(_to_plain(cfg), f, sort_keys=False)
104
+
105
+ mode = args.split
106
+ ds = SceneStreamerDataset(cfg, mode)
107
+
108
+ # Iterate to materialize cache files.
109
+ for i in range(len(ds)):
110
+ _ = ds[i]
111
+ if args.limit is not None and (i + 1) >= args.limit:
112
+ break
113
+
114
+ metrics = {
115
+ "status": "ok",
116
+ "mode": mode,
117
+ "train_dir": getattr(cfg.DATA, "TRAINING_DATA_DIR", None),
118
+ "test_dir": getattr(cfg.DATA, "TEST_DATA_DIR", None),
119
+ "limit": args.limit,
120
+ }
121
+ with open(run.metrics_path, "w") as f:
122
+ json.dump(metrics, f, indent=2)
123
+
124
+ print(str(run.run_dir))
125
+
126
+
127
+ def _load_model_from_args(args: argparse.Namespace):
128
+ import torch
129
+
130
+ from scenestreamer.utils import utils
131
+
132
+ device = torch.device(args.device)
133
+ if args.hf_repo:
134
+ return utils.get_model(huggingface_repo=args.hf_repo, huggingface_file=args.hf_file, device=device)
135
+ if args.ckpt:
136
+ return utils.get_model(checkpoint_path=args.ckpt, device=device)
137
+ raise ValueError("Must provide either --hf-repo/--hf-file or --ckpt")
138
+
139
+
140
+ def cmd_table1(args: argparse.Namespace) -> None:
141
+ from scenestreamer.paper.table1_mmd import run_table1_mmd
142
+
143
+ pl_model = _load_model_from_args(args)
144
+ run_dir = run_table1_mmd(
145
+ pl_model=pl_model,
146
+ dataset_dir=args.dataset_dir,
147
+ split=args.split,
148
+ limit=args.limit,
149
+ artifacts_dir=args.artifacts_dir,
150
+ run_id=args.run_id,
151
+ seed=args.seed,
152
+ )
153
+ print(str(run_dir))
154
+
155
+
156
+ def cmd_table2(args: argparse.Namespace) -> None:
157
+ from scenestreamer.paper.table2_motion import run_table2_motion
158
+
159
+ pl_model = _load_model_from_args(args)
160
+ run_dir = run_table2_motion(
161
+ pl_model=pl_model,
162
+ dataset_dir=args.dataset_dir,
163
+ split=args.split,
164
+ mode=args.mode,
165
+ num_modes=args.num_modes,
166
+ limit=args.limit,
167
+ artifacts_dir=args.artifacts_dir,
168
+ run_id=args.run_id,
169
+ seed=args.seed,
170
+ )
171
+ print(str(run_dir))
172
+
173
+
174
+ def cmd_densify_demo(args: argparse.Namespace) -> None:
175
+ from scenestreamer.paper.densify_demo import run_densify_demo
176
+
177
+ pl_model = _load_model_from_args(args)
178
+ run_dir = run_densify_demo(
179
+ pl_model=pl_model,
180
+ dataset_dir=args.dataset_dir,
181
+ split=args.split,
182
+ scenario_index=args.scenario_index,
183
+ max_agents=args.max_agents,
184
+ force_no_end=args.force_no_end,
185
+ artifacts_dir=args.artifacts_dir,
186
+ run_id=args.run_id,
187
+ seed=args.seed,
188
+ )
189
+ print(str(run_dir))
190
+
191
+ def _run_module_as_main(module: str, argv: list[str]) -> None:
192
+ old_argv = sys.argv[:]
193
+ try:
194
+ sys.argv = [module] + argv
195
+ runpy.run_module(module, run_name="__main__")
196
+ finally:
197
+ sys.argv = old_argv
198
+
199
+
200
+ def cmd_table3_train(args: argparse.Namespace) -> None:
201
+ _run_module_as_main("scenestreamer.rl_train.train.train_td3", args.args)
202
+
203
+
204
+ def cmd_table3_eval(args: argparse.Namespace) -> None:
205
+ _run_module_as_main("scenestreamer.rl_train.train.eval_policy", args.args)
206
+
207
+
208
+ def build_parser() -> argparse.ArgumentParser:
209
+ parser = argparse.ArgumentParser(prog="scenestreamer", description="SceneStreamer paper reproduction CLI")
210
+ sub = parser.add_subparsers(dest="cmd", required=True)
211
+
212
+ def add_run_args(p: argparse.ArgumentParser) -> None:
213
+ p.add_argument("--artifacts-dir", default="artifacts", help="Directory to write run artifacts")
214
+ p.add_argument("--run-id", default=None, help="Run ID (default: timestamp)")
215
+ p.add_argument("--seed", type=int, default=0)
216
+
217
+ def add_model_args(p: argparse.ArgumentParser) -> None:
218
+ p.add_argument("--device", default="cuda", help="torch device string, e.g. cuda or cpu")
219
+ p.add_argument("--ckpt", default=None, help="Path to a .ckpt checkpoint")
220
+ p.add_argument("--hf-repo", default=None, help="HuggingFace repo id, e.g. user/repo")
221
+ p.add_argument("--hf-file", default=None, help="HuggingFace filename, e.g. model.ckpt")
222
+
223
+ # preprocess
224
+ p = sub.add_parser("preprocess", help="Preprocess ScenarioNet SD dataset and build cache")
225
+ add_run_args(p)
226
+ p.add_argument("--config", default="cfgs/motion_default.yaml")
227
+ p.add_argument("--set", action="append", default=[], help="Override config KEY=VALUE (repeatable)")
228
+ p.add_argument("--train-dir", default=None)
229
+ p.add_argument("--test-dir", default=None)
230
+ p.add_argument("--split", choices=["training", "test"], default="training")
231
+ p.add_argument("--limit", type=int, default=None)
232
+ p.set_defaults(func=cmd_preprocess)
233
+
234
+ # table1
235
+ p = sub.add_parser("table1", help="Table 1: initial state MMD (strict + relaxed)")
236
+ add_run_args(p)
237
+ add_model_args(p)
238
+ p.add_argument("--dataset-dir", required=True)
239
+ p.add_argument("--split", choices=["training", "test"], default="test")
240
+ p.add_argument("--limit", type=int, default=None)
241
+ p.set_defaults(func=cmd_table1)
242
+
243
+ # table2
244
+ p = sub.add_parser("table2", help="Table 2: motion prediction (ADE/FDE + ADD/FDD)")
245
+ add_run_args(p)
246
+ add_model_args(p)
247
+ p.add_argument("--dataset-dir", required=True)
248
+ p.add_argument("--split", choices=["training", "test"], default="test")
249
+ p.add_argument("--mode", choices=["motion", "full"], default="motion")
250
+ p.add_argument("--num-modes", type=int, default=6)
251
+ p.add_argument("--limit", type=int, default=None)
252
+ p.set_defaults(func=cmd_table2)
253
+
254
+ # demo
255
+ p = sub.add_parser("densify-demo", help="Qualitative densification demo (generate to max agents)")
256
+ add_run_args(p)
257
+ add_model_args(p)
258
+ p.add_argument("--dataset-dir", required=True)
259
+ p.add_argument("--split", choices=["training", "test"], default="test")
260
+ p.add_argument("--scenario-index", type=int, default=0)
261
+ p.add_argument("--max-agents", type=int, default=128)
262
+ p.add_argument("--force-no-end", action="store_true", help="Disable end token so it keeps generating agents")
263
+ p.set_defaults(func=cmd_densify_demo)
264
+
265
+ p = sub.add_parser("table3-train", help="Table 3: RL training (pass-through to train_td3.py)")
266
+ p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to train_td3.py")
267
+ p.set_defaults(func=cmd_table3_train)
268
+
269
+ p = sub.add_parser("table3-eval", help="Table 3: RL evaluation (pass-through to eval_policy.py)")
270
+ p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to eval_policy.py")
271
+ p.set_defaults(func=cmd_table3_eval)
272
+
273
+ return parser
274
+
275
+
276
+ def main(argv: list[str] | None = None) -> None:
277
+ parser = build_parser()
278
+ args = parser.parse_args(argv)
279
+ try:
280
+ args.func(args)
281
+ except ModuleNotFoundError as e:
282
+ # Most common in a fresh environment: scenarionet / waymo-open-dataset missing.
283
+ msg = str(e)
284
+ if "scenarionet" in msg:
285
+ raise SystemExit(
286
+ "Missing dependency 'scenarionet'. Install it via:\n"
287
+ " pip install git+https://github.com/metadriverse/scenarionet.git\n"
288
+ ) from e
289
+ raise
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
scenestreamer/clustering.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ nohup python clustering.py --data 3 > clustering_obj3_all_nomin.log 2>&1 &
2
+ nohup python clustering.py --data 2 > clustering_obj2_all_nomin.log 2>&1 &
3
+ nohup python clustering.py --data 1 > clustering_obj1_all_nomin.log 2>&1 &
4
+
5
+ #nohup python clustering.py --data 3 --min_scale 0.5 > clustering_obj3_all.log 2>&1 &
6
+ #nohup python clustering.py --data 2 --min_scale 0.5 > clustering_obj2_all.log 2>&1 &
7
+ #nohup python clustering.py --data 1 --min_scale 0.5 > clustering_obj1_all.log 2>&1 &
scenestreamer/dataset/__init__.py ADDED
File without changes
scenestreamer/dataset/constants.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Define a lot of constants. It should be totally removed as most of them should be defined by MetaDrive / ScenarioNet.
3
+ """
4
+ from metadrive.scenario.scenario_description import MetaDriveType
5
+
6
+ # NUM_TYPES = 3
7
+ NUM_TYPES = 5
8
+
9
+ MAP_FEATURE_STATE_DIM = 27
10
+ TRAFFIC_LIGHT_STATE_DIM = 7
11
+
12
+ AGENT_STATE_DIM = 16
13
+
14
+ # ACTOR_PREDICT_DIM = 6 + 2 + 4 + 5 # 3 for position, 1 for heading, 2 for velocity, 5 for types
15
+ TRAFFIC_LIGHT_PREDICT_DIM = 9 # 9 original possible state
16
+
17
+ # TODO(pzh): Do we have to do the normalization? Shouldn't the layer norm solve this?
18
+ # POSITION_XY_RANGE = 100.
19
+ # LOCAL_POSITION_XY_RANGE = 5.
20
+ # HEADING_RANGE = np.pi
21
+ # VELOCITY_XY_RANGE = 10.
22
+ # SIZE_RANGE = 5.
23
+ # MAP_VECTOR_XY_RANGE = 50.
24
+
25
+ # TODO(pzh): Consider remove this.
26
+ object_type_to_int = {
27
+ MetaDriveType.UNSET: 0,
28
+ MetaDriveType.VEHICLE: 1,
29
+ MetaDriveType.PEDESTRIAN: 2,
30
+ MetaDriveType.CYCLIST: 3,
31
+ MetaDriveType.OTHER: 4
32
+ }
33
+
34
+ # TODO(pzh): Consider remove this.
35
+ object_int_to_type = {
36
+ -1: MetaDriveType.UNSET,
37
+ 0: MetaDriveType.UNSET,
38
+ 1: MetaDriveType.VEHICLE,
39
+ 2: MetaDriveType.PEDESTRIAN,
40
+ 3: MetaDriveType.CYCLIST,
41
+ 4: MetaDriveType.OTHER
42
+ }
43
+
44
+ HEADING_PLACEHOLDER = -100 # For the object that has no heading, set this.
scenestreamer/dataset/datamodule.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a wrapper to wrap our dataset as a lightning datamodule.
3
+ """
4
+ import lightning.pytorch as pl
5
+ from torch.utils.data import DataLoader
6
+
7
+ from scenestreamer.dataset import dataset
8
+
9
+
10
+ class SceneStreamerDataModule(pl.LightningDataModule):
11
+ def __init__(
12
+ self, config, train_batch_size, train_num_workers, train_prefetch_factor, val_batch_size, val_num_workers,
13
+ val_prefetch_factor
14
+ ):
15
+ super().__init__()
16
+ self.config = config
17
+ self.train_batch_size = train_batch_size
18
+ self.train_num_workers = train_num_workers
19
+ self.train_prefetch_factor = train_prefetch_factor
20
+ self.val_batch_size = val_batch_size
21
+ self.val_num_workers = val_num_workers
22
+ self.val_prefetch_factor = val_prefetch_factor
23
+
24
+ def setup(self, stage: str):
25
+ self.train_dataset = dataset.SceneStreamerDataset(config=self.config, mode="training")
26
+ self.val_dataset = dataset.SceneStreamerDataset(config=self.config, mode="test")
27
+
28
+ def train_dataloader(self):
29
+ return DataLoader(
30
+ self.train_dataset,
31
+ batch_size=self.train_batch_size,
32
+ pin_memory=True,
33
+ num_workers=self.train_num_workers,
34
+ shuffle=True,
35
+ persistent_workers=True if self.train_num_workers > 0 else False,
36
+ collate_fn=self.train_dataset.collate_batch,
37
+ prefetch_factor=self.train_prefetch_factor if self.train_num_workers > 0 else None,
38
+ )
39
+
40
+ def val_dataloader(self):
41
+ return DataLoader(
42
+ self.val_dataset,
43
+ batch_size=self.val_batch_size,
44
+ pin_memory=True,
45
+ num_workers=self.val_num_workers,
46
+ shuffle=False,
47
+ collate_fn=self.val_dataset.collate_batch,
48
+ prefetch_factor=self.val_prefetch_factor if self.val_num_workers > 0 else None,
49
+ )
scenestreamer/dataset/dataset.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Create a pytorch dataset class for loading scenario files and padding data entries.
3
+ """
4
+ import copy
5
+ import json
6
+ import os
7
+ import pathlib
8
+ import pickle
9
+
10
+ try:
11
+ import hydra
12
+ except ModuleNotFoundError: # optional for core library usage
13
+ hydra = None
14
+ import numpy as np
15
+ from scenarionet import read_dataset_summary, read_scenario
16
+ from torch.utils.data import Dataset
17
+
18
+ from scenestreamer.dataset.preprocessor import preprocess_scenario_description
19
+ from scenestreamer.utils import global_config
20
+ from scenestreamer.utils import utils
21
+
22
+ # import lmdb
23
+
24
+ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
25
+ QA_DATASET_MAPPING = {}
26
+ ADV_INFO_DICT = {}
27
+
28
+
29
+ class NoMapFeatureError(Exception):
30
+ pass
31
+
32
+
33
+ class LMDBDatasetReader:
34
+ def __init__(self, base_path):
35
+ self.base_path = base_path
36
+ # Load the lookup table that maps sample keys to LMDB file names
37
+ # Search recursively all subfolder to find lookup.json
38
+ self.lookup = {}
39
+ for root, dirs, files in os.walk(self.base_path):
40
+ if "lookup.json" in files:
41
+ lookup_path = os.path.join(root, "lookup.json")
42
+ with open(lookup_path, "r") as f:
43
+ lookup = json.load(f)
44
+ self.lookup.update(lookup)
45
+ self.lmdb_cache = {} # Cache for open LMDB environments
46
+
47
+
48
+ # def _get_lmdb_env(self, lmdb_name):
49
+ # """Fetches or opens an LMDB environment for reading."""
50
+ # if lmdb_name not in self.lmdb_cache:
51
+ # self.lmdb_cache[lmdb_name] = lmdb.open(lmdb_name, readonly=True)
52
+ # return self.lmdb_cache[lmdb_name]
53
+
54
+ # def load_sample(self, key):
55
+ # """Loads a preprocessed sample by key."""
56
+ # lmdb_name = self.lookup.get(key)
57
+ # if lmdb_name is None:
58
+ # raise KeyError(f"Sample {key} not found in lookup.")
59
+ # env = self._get_lmdb_env(lmdb_name)
60
+ # with env.begin() as txn:
61
+ # npz_bytes = txn.get(key.encode('ascii'))
62
+ # if npz_bytes:
63
+ # with io.BytesIO(npz_bytes) as buffer:
64
+ # data = np.load(buffer, allow_pickle=True)
65
+ # return {name: data[name] for name in data.files} # Return data as a dictionary
66
+ # return None
67
+
68
+ # def close(self):
69
+ # """Closes all open LMDB environments."""
70
+ # for env in self.lmdb_cache.values():
71
+ # env.close()
72
+
73
+
74
+ def process_QA_text_label(QA_dict):
75
+ # TODO: do we need to form label for each individual agent? Rightnow it is just a single label
76
+ labels = {}
77
+
78
+ env_a = QA_dict['env_a']
79
+ labels['env'] = ' '.join(env_a)
80
+
81
+ ego_a = QA_dict['ego_a']
82
+ labels['ego'] = ' '.join(ego_a)
83
+
84
+ int_a = QA_dict['int_a']
85
+ labels['int'] = ' '.join(int_a)
86
+
87
+ return labels
88
+
89
+
90
+ def get_file_paths(directory):
91
+ file_paths = []
92
+ # Traverse the directory
93
+ for root, dirs, files in os.walk(directory):
94
+ for file in files:
95
+ # Get the full path and add it to the list
96
+ full_path = os.path.join(root, file)
97
+ file_paths.append(full_path)
98
+ return
99
+
100
+
101
+ def load_json_to_dict(file_path):
102
+ """
103
+ Load a JSON file into a Python dictionary.
104
+
105
+ :param file_path: Path to the JSON file
106
+ :return: Dictionary containing the JSON data
107
+ """
108
+ try:
109
+ with open(file_path, 'r') as file:
110
+ data = json.load(file)
111
+ return data
112
+ except FileNotFoundError:
113
+ print(f"Error: The file at {file_path} was not found.")
114
+ except json.JSONDecodeError:
115
+ print(f"Error: The file at {file_path} is not a valid JSON file.")
116
+ except Exception as e:
117
+ print(f"An unexpected error occurred: {e}")
118
+
119
+ return None
120
+
121
+
122
+ class SceneStreamerDataset(Dataset):
123
+ """
124
+ SceneStreamer dataset class. Returns data_dict for each scenario.
125
+ Init args:
126
+ mode: "training" or "test".
127
+ config:
128
+ - model: Details about the model architecture.
129
+ - data: Data directories, sample intervals, number of agents, etc.
130
+ - evaluation: predict_all_agents, delete_eval_result (TODO: Add ScenarioDescription passthrough as a flag in the config.)
131
+ - optimization: Training hyperparameters.
132
+ - preprocessing: Max number of agents, map features, traffic lights, padding, etc.
133
+ - root_dir: Self-explanatory.
134
+ - sampling: Inference sampling parameters.
135
+ - tokenization: The part of the config passed to the tokenizer.
136
+ """
137
+ def __init__(self, config, mode):
138
+ super().__init__()
139
+ self.mode = mode
140
+ self.config = config
141
+ dataset_cfg = self.config.DATA
142
+
143
+ self.max_map_features = config.PREPROCESSING.MAX_MAP_FEATURES
144
+ self.max_vectors_per_map_feature = config.PREPROCESSING.MAX_VECTORS
145
+ self.max_agents = config.PREPROCESSING.MAX_AGENTS
146
+ self.max_traffic_lights = config.PREPROCESSING.MAX_TRAFFIC_LIGHTS
147
+ self.padding_to_max = config.PREPROCESSING.PADDING_TO_MAX
148
+
149
+ # We are expecting the data_dir to be either an absolute path or a relative path w.r.t. the repo root.
150
+ if mode == "training":
151
+ self.data_dir = global_config.ROOT_DIR / dataset_cfg.TRAINING_DATA_DIR
152
+ elif mode == "test":
153
+ self.data_dir = global_config.ROOT_DIR / dataset_cfg.TEST_DATA_DIR
154
+ else:
155
+ raise ValueError(f"Unknown mode {mode}.")
156
+
157
+ # summary_dict: A dictionary of .pkl filenames to ingest. Filenames (keys) are mapped to metadata objects.
158
+ # summary_list: Keys of summary_dict, in order of ingestion.
159
+ # mapping: A dict mapping scenario IDs to the folder that hosts their files.
160
+ summary_dict, summary_list, mapping = read_dataset_summary(self.data_dir)
161
+
162
+ # We might want to use a subset of scenarios.
163
+ if self.mode == "training":
164
+ interval = dataset_cfg.SAMPLE_INTERVAL_TRAINING
165
+ elif self.mode == "test":
166
+ interval = dataset_cfg.SAMPLE_INTERVAL_TEST
167
+ else:
168
+ raise ValueError(f"Unknown mode {self.mode}.")
169
+
170
+ if "SD_PASSTHROUGH" in config.DATA:
171
+ self.return_scenario_description = config.DATA["SD_PASSTHROUGH"]
172
+ else: # Default to False.
173
+ self.return_scenario_description = False
174
+
175
+ summary_list = summary_list[::interval]
176
+ # self.data_summary_dict = {k: summary_dict[k] for k in summary_list}
177
+ self.data_mapping = {k: mapping[k] for k in summary_list}
178
+ self.length = len(summary_list)
179
+ self.use_cache_logged = False
180
+
181
+ if self.config.BACKWARD_PREDICTION and self.mode == "training":
182
+ self.real_length = self.length
183
+ self.length = self.length * 2
184
+
185
+ # Convert each string to sequence of codepoints (integer),
186
+ # and then pack them into a numpy array.
187
+ # NOTE(pzh): I forgot why I wrote this. Seems like some issues in multiprocessing.
188
+
189
+ # seqs: A list of np.arrays, each representing the ascii values of a string.
190
+ seqs = [utils.string_to_sequence(s) for s in summary_list]
191
+
192
+ # strings_v: ascii values of all strings, concatenated.
193
+ # strings_o: offsets of each string in strings_v.
194
+ if len(seqs) == 0:
195
+ raise ValueError("No scenarios found in the dataset: {}".format(self.data_dir))
196
+ self.strings_v, self.strings_o = utils.pack_sequences(seqs)
197
+
198
+ # if self.config.DATA.USE_LMDB and self.mode == "training":
199
+ # cache_folder = pathlib.Path(self.data_dir) / "cache"
200
+ # assert cache_folder.is_dir()
201
+ # self.reader = LMDBDatasetReader(cache_folder) # LMDB Reader to load samples
202
+
203
+ from scenestreamer.tokenization import get_tokenizer
204
+ self.tokenizer = get_tokenizer(config=self.config)
205
+
206
+ def __len__(self):
207
+ return self.length
208
+
209
+ def __getitem__(self, index):
210
+ # Unpack the stored codepoints at the correct index into a filename string.
211
+
212
+ use_backward_prediction = False
213
+ if self.config.BACKWARD_PREDICTION and self.mode == "training":
214
+ if index >= self.real_length:
215
+ index = index - self.real_length
216
+ use_backward_prediction = True
217
+
218
+ seq = utils.unpack_sequence(self.strings_v, self.strings_o, index)
219
+ string = utils.sequence_to_string(seq)
220
+ file_name = string
221
+
222
+ try:
223
+ data_dict = self.create_scene_level_data(file_name, index, use_backward_prediction)
224
+ except NoMapFeatureError:
225
+ # This is workaround for Waymo test set where some scenarios do not have map features.
226
+ return self.__getitem__(index + 1)
227
+
228
+ # If self.return_scenario_description is true, data_dict has an extra key [raw_scenario_description] that contains the ScenarioDescription object.
229
+ return data_dict
230
+
231
+ def create_scene_level_data(self, file_name, index, use_backward_prediction=False):
232
+ """
233
+ Reads a scenario file and preprocesses it.
234
+ """
235
+ assert not self.config.DATA.USE_LMDB, "LMDB is not supported."
236
+ try:
237
+ # scenario: A ScenarioDescription instance.
238
+ cache = None
239
+ scenario = None
240
+ cache_path = None
241
+
242
+ if self.config.DATA.USE_CACHE:
243
+ cache_folder = pathlib.Path(self.data_dir) / "cache"
244
+ if cache_folder.is_dir() is False:
245
+ cache_folder.mkdir(exist_ok=True)
246
+
247
+ cache_path = pathlib.Path(self.data_dir) / "cache" / file_name
248
+ if cache_path.is_file():
249
+
250
+ try:
251
+ with open(cache_path, "rb") as f:
252
+ cache = pickle.load(f)
253
+
254
+ if self.use_cache_logged is False:
255
+ print("=====================================")
256
+ print("=====================================")
257
+ print("\t*** WARNING ***")
258
+ print("\tYou are using cache files!!!")
259
+ print("\tIn folder: ", cache_folder)
260
+ print("\tThere are ", len(list(cache_folder.glob("*"))), " cache files!!!")
261
+ print("=====================================")
262
+ print("=====================================")
263
+
264
+ self.use_cache_logged = True
265
+
266
+ return cache
267
+ except EOFError as e:
268
+ print(f"Error in reading cache file: {cache_path=}")
269
+
270
+ scenario = read_scenario(
271
+ dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
272
+ )
273
+
274
+ else:
275
+ scenario = read_scenario(
276
+ dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
277
+ )
278
+ # print("Cannot find cache file: ", cache_path, "Creating one.")
279
+
280
+ else:
281
+ # if self.config.DATA.USE_LMDB and self.mode == "training":
282
+ # cache = self.reader.load_sample(file_name)
283
+ # else:
284
+ scenario = read_scenario(
285
+ dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
286
+ )
287
+
288
+ except EOFError as e:
289
+ print(f"{self.data_dir=}, {self.data_mapping=}, {file_name=}")
290
+ raise e
291
+ assert self.mode in ["training", "test"], self.mode
292
+ ret = {}
293
+
294
+ if len(scenario["map_features"]) == 0:
295
+ raise NoMapFeatureError
296
+
297
+ if self.return_scenario_description:
298
+ ret["raw_scenario_description"] = copy.deepcopy(scenario)
299
+
300
+ # TODO: Remove error handling after debugging.
301
+ try:
302
+ preprocessed_scenario_description = preprocess_scenario_description(
303
+ scenario=scenario,
304
+ # cache=cache,
305
+ config=copy.deepcopy(self.config),
306
+ in_evaluation=self.mode != "training",
307
+ keep_all_data=self.config.PREPROCESSING.get("keep_all_data", False),
308
+ backward_prediction=use_backward_prediction,
309
+ tokenizer=self.tokenizer,
310
+ # cache_path=cache_path,
311
+ )
312
+ preprocessed_scenario_description["file_name"] = file_name
313
+ except Exception as e:
314
+ print(f"Error in preprocessing {file_name=}, {index=}, {scenario['id']=}")
315
+ # Ensure that the exception is not swallowed by adding this.
316
+ raise RuntimeError(
317
+ f"{file_name=}, {index=}, {scenario['id']=}. Error in create_scene_level_data: {e}"
318
+ ) from e
319
+
320
+ ret.update(preprocessed_scenario_description)
321
+ ret.update({"metadata/scenario_id": scenario['id']})
322
+
323
+ if cache_path is not None:
324
+ with open(cache_path, "wb") as f:
325
+ pickle.dump(ret, f)
326
+ # print("Writing cache file: ", cache_path)
327
+
328
+ return ret
329
+
330
+ def collate_batch(self, batch_list):
331
+ """
332
+ Output format:
333
+
334
+ agent_feature: [B, T, #agents, D]
335
+ agent_feature_position: [B, T, #agents, 3]
336
+ map_feature: [B, T, #mapfeat, #points, D]
337
+ map_feature_valid_mask: [B, T, #mapfeat, #points]
338
+ map_feature_position: [B, T, #mapfeat, 3]
339
+ """
340
+ data_dict_sample = batch_list[0]
341
+
342
+ num_map_feat, num_points, _ = data_dict_sample["encoder/map_feature"].shape
343
+
344
+ data_dict = {}
345
+ object_keys = [
346
+ "raw_scenario_description",
347
+ "encoder/track_name",
348
+ "decoder/track_name",
349
+ "eval/track_name",
350
+ # "scenario_id",
351
+ # "in_evaluation"
352
+ ] # Keys exempt from padding and tensor conversion.
353
+
354
+ for k in set(data_dict_sample.keys()):
355
+ if k not in object_keys:
356
+ if not isinstance(data_dict_sample[k], np.ndarray):
357
+ assert isinstance(data_dict_sample[k], (int, float, bool, str)), (k, type(data_dict_sample[k]))
358
+ if isinstance(data_dict_sample[k], str):
359
+ data_dict[k] = np.array([b[k] for b in batch_list])
360
+ else:
361
+ data_dict[k] = utils.numpy_to_torch(np.array([b[k] for b in batch_list]))
362
+ continue
363
+ # else:
364
+ # if batch_list[0][k].dtype == np.object:
365
+ # data_dict[k] = [b[k] for b in batch_list]
366
+ # continue
367
+
368
+ val_list = [utils.numpy_to_torch(b[k]) for b in batch_list]
369
+
370
+ # Map features that have vectors' information
371
+ if k in [
372
+ "encoder/map_feature",
373
+ "vis/map_feature",
374
+ "raw/map_feature",
375
+ "encoder/map_feature_valid_mask",
376
+ ]:
377
+ data_dict[k] = utils.padding_1st_and_2nd_dim(
378
+ val_list,
379
+ max_1st_dim=self.max_map_features if self.padding_to_max else None,
380
+ max_2nd_dim=self.max_vectors_per_map_feature if self.padding_to_max else None
381
+ )
382
+
383
+ # Map features that have aggregated info from vectors
384
+ elif k in [
385
+ "encoder/map_heading",
386
+ "encoder/map_position",
387
+ "encoder/map_valid_mask",
388
+ ]:
389
+ data_dict[k] = utils.padding_1st_dim(
390
+ val_list, max_1st_dim=self.max_map_features if self.padding_to_max else None
391
+ )
392
+
393
+ # Traffic light features that have temporal dim
394
+ elif k in [
395
+ "encoder/traffic_light_feature",
396
+ "encoder/traffic_light_state",
397
+ "encoder/traffic_light_valid_mask",
398
+ ]:
399
+
400
+ if self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE:
401
+ data_dict[k] = utils.padding_1st_dim(
402
+ val_list, max_1st_dim=self.max_traffic_lights if self.padding_to_max else None
403
+ )
404
+ else:
405
+ data_dict[k] = utils.padding_1st_and_2nd_dim(
406
+ val_list, max_2nd_dim=self.max_traffic_lights if self.padding_to_max else None
407
+ )
408
+
409
+ # Traffic light features that do not have temporal dim
410
+ elif k in [
411
+ "encoder/traffic_light_position",
412
+ "encoder/traffic_light_heading",
413
+ "encoder/traffic_light_map_id",
414
+ ]:
415
+ data_dict[k] = utils.padding_1st_dim(
416
+ val_list, max_1st_dim=self.max_traffic_lights if self.padding_to_max else None
417
+ )
418
+
419
+ # Agent features
420
+ elif k in [
421
+ "encoder/agent_feature",
422
+ "encoder/agent_position",
423
+ "encoder/agent_valid_mask",
424
+ "encoder/agent_heading",
425
+ "encoder/agent_velocity",
426
+ "decoder/modeled_agent_position",
427
+ "decoder/modeled_agent_heading",
428
+ "decoder/modeled_agent_velocity",
429
+ "decoder/modeled_agent_delta",
430
+ ]:
431
+ data_dict[k] = utils.padding_1st_and_2nd_dim(
432
+ val_list, max_2nd_dim=self.max_agents if self.padding_to_max else None
433
+ )
434
+
435
+ # Other data that does not pass the model or does not need regular shapes
436
+ elif k in [
437
+ # "encoder/modeled_agent_id",
438
+ # "action_label/labeled_agent_id",
439
+ "metadata/map_center", # "decoder/input_step",
440
+ # "decoder/input_intra_step",
441
+ "encoder/current_agent_heading",
442
+ "decoder/current_agent_heading",
443
+ "encoder/current_agent_shape",
444
+ "decoder/current_agent_shape",
445
+ "eval/current_agent_heading",
446
+ "encoder/current_agent_valid_mask",
447
+ "decoder/current_agent_valid_mask",
448
+ "eval/current_agent_valid_mask",
449
+ # "decoder/current_agent_valid_mask", #
450
+ # "decoder/modeled_agent_indices",
451
+ # For gen model:
452
+ # "decoder/input_token_valid_mask",
453
+ # "decoder/should_predict",
454
+ # "decoder/is_gt",
455
+ # "eval/should_predict_motion",
456
+ ]:
457
+ data_dict[k] = utils.padding_1st_dim(val_list)
458
+
459
+ elif k in [
460
+ "decoder/input_action_valid_mask",
461
+ "encoder/current_agent_position",
462
+ "decoder/current_agent_position",
463
+ "encoder/current_agent_velocity",
464
+ "decoder/current_agent_velocity",
465
+ "decoder/target_action_valid_mask",
466
+ #"decoder/future_agent_position",
467
+ #"decoder/future_agent_heading",
468
+ #"decoder/future_agent_valid_mask",
469
+ #"decoder/future_agent_velocity",
470
+ #"encoder/future_agent_position",
471
+ #"encoder/future_agent_heading",
472
+ #"encoder/future_agent_valid_mask",
473
+ #"encoder/future_agent_velocity",
474
+ "decoder/agent_position",
475
+ "decoder/agent_heading",
476
+ "decoder/agent_velocity",
477
+ "decoder/agent_valid_mask",
478
+ "eval/agent_velocity",
479
+ "eval/agent_heading",
480
+ "eval/agent_position",
481
+ "eval/agent_valid_mask",
482
+ "encoder/agent_shape",
483
+ "decoder/agent_shape",
484
+ "eval/agent_shape", # "decoder/target_valid_mask",
485
+ "decoder/input_agent_motion",
486
+ "decoder/target_agent_motion",
487
+ "decoder/dest_map_index_valid_mask",
488
+ ]:
489
+ data_dict[k] = utils.padding_1st_and_2nd_dim(val_list)
490
+
491
+ elif k in [
492
+ "encoder/agent_type",
493
+ "decoder/agent_type",
494
+ "encoder/modeled_agent_type",
495
+ "eval/agent_type", # "eval/raw_agent_name",
496
+ "encoder/object_of_interest_name",
497
+ "decoder/object_of_interest_name",
498
+ "metadata/sdc_name", # "eval/modeled_agent_id",
499
+ "encoder/object_of_interest_id",
500
+ "decoder/object_of_interest_id",
501
+ "encoder/modeled_agent_id", # "decoder/modeled_agent_id",
502
+ "encoder/agent_id",
503
+ "decoder/agent_id",
504
+ "decoder/labeled_agent_id",
505
+ "decoder/label_turning",
506
+ "decoder/label_acceleration",
507
+ "decoder/label_safety",
508
+ # For gen model:
509
+ # "decoder/input_token_id",
510
+ # "decoder/causal_mask_offset",
511
+ ]:
512
+ data_dict[k] = utils.padding_1st_dim(val_list, fill=-1)
513
+
514
+ elif k in [
515
+ "decoder/dest_map_index",
516
+ "decoder/dest_map_index_gt",
517
+ ]:
518
+ data_dict[k] = utils.padding_1st_and_2nd_dim(val_list, fill=-1)
519
+
520
+ elif k in [
521
+ "decoder/input_action",
522
+ "decoder/target_action",
523
+ "decoder/input_action_for_trafficgen",
524
+
525
+ "decoder/current_agent_shape_for_trafficgen",
526
+ "decoder/modeled_agent_heading_for_trafficgen",
527
+ "decoder/modeled_agent_position_for_trafficgen",
528
+ "decoder/modeled_agent_velocity_for_trafficgen",
529
+ "decoder/input_action_valid_mask_for_trafficgen",
530
+ "decoder/modeled_agent_delta_for_trafficgen",
531
+ "decoder/input_action_feature_for_trafficgen",
532
+ "decoder/target_offset_for_trafficgen",
533
+ "decoder/input_offset_for_trafficgen",
534
+ "decoder/agent_id_for_trafficgen",
535
+ "decoder/trafficgen_position",
536
+ "decoder/trafficgen_heading",
537
+ "decoder/agent_type_for_trafficgen",
538
+ ]:
539
+ data_dict[k] = utils.padding_all_dims(val_list, fill=-1)
540
+
541
+ elif k in object_keys:
542
+ # Passthrough: Have the data_dict[object] contain a list of objects.
543
+ data_dict[k] = [b[k] for b in batch_list]
544
+
545
+ elif k in [
546
+ "encoder/sdc_index",
547
+ ]:
548
+ pass
549
+
550
+ else:
551
+ raise ValueError("Unknown key: {}".format(k))
552
+
553
+ return data_dict
554
+
555
+
556
+ if hydra is not None:
557
+ @hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1009_safety_action_debug.yaml")
558
+ def debug(config):
559
+ test_dataset = SceneStreamerDataset(config, "training")
560
+ ddd = iter(test_dataset)
561
+ count = 0
562
+ buggy_count = 0
563
+ while True:
564
+ if count == 3:
565
+ return
566
+ try:
567
+ data = next(ddd)
568
+ count += 1
569
+
570
+ assert data["decoder/label_safety"][data["decoder/labeled_agent_id"]].sum() > 1
571
+
572
+ except StopIteration:
573
+ break
574
+
575
+ except AssertionError:
576
+ print("ni collision")
577
+ buggy_count += 1
578
+ print("scenario_id", data["scenario_id"])
579
+ print("data['decoder/label_safety']", data["decoder/label_safety"])
580
+ print("data['decoder/labeled_agent_id']", data["decoder/labeled_agent_id"])
581
+ print("track_name", data["decoder/track_name"][data["decoder/labeled_agent_id"]])
582
+
583
+ print("buggy_count:", buggy_count)
584
+ print("count", count)
585
+ print("End")
586
+
587
+ @hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml")
588
+ def read_traffic_light_state(config):
589
+ test_dataset = SceneStreamerDataset(config, "training")
590
+
591
+ total_tl = 0
592
+ total_green = 0
593
+ total_yellow = 0
594
+ total_red = 0
595
+ total_unknown = 0
596
+ total_mix = 0
597
+ import tqdm
598
+
599
+ for data in tqdm.tqdm(test_dataset):
600
+ tl = data["encoder/traffic_light_feature"]
601
+ mask = data["encoder/traffic_light_valid_mask"]
602
+
603
+ for i in range(tl.shape[1]):
604
+ if mask[:, i].any():
605
+ is_green = tl[:, i, 3].astype(bool).any()
606
+ is_yellow = tl[:, i, 4].astype(bool).any()
607
+ is_red = tl[:, i, 5].astype(bool).any()
608
+ is_unknown = tl[:, i, 6].astype(bool).any()
609
+
610
+ total_tl += 1
611
+ total_green += is_green
612
+ total_yellow += is_yellow
613
+ total_red += is_red
614
+ total_unknown += is_unknown
615
+ total_mix += (is_green and is_yellow) or (is_green and is_red) or (is_yellow and is_red)
616
+
617
+ print("total_tl:", total_tl)
618
+ print("total_green: {}\t{:.4f}".format(total_green, total_green / total_tl))
619
+ print("total_yellow: {}\t{:.4f}".format(total_yellow, total_yellow / total_tl))
620
+ print("total_red: {}\t{:.4f}".format(total_red, total_red / total_tl))
621
+ print("total_unknown: {}\t{:.4f}".format(total_unknown, total_unknown / total_tl))
622
+ print("total_mix: {}\t{:.4f}".format(total_mix, total_mix / total_tl))
623
+ else:
624
+ debug = None
625
+ read_traffic_light_state = None
626
+
627
+
628
+ if __name__ == '__main__':
629
+ # debug()
630
+ read_traffic_light_state()
scenestreamer/dataset/make_lmdb.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Only the TRAINING_DATA_DIR will be used in the code below.
3
+ Usage:
4
+
5
+ python -m scenestreamer.dataset.make_lmdb \
6
+ --config-name="1024_gpt" DATA.TEST_DATA_DIR='data/20scenarios' \
7
+ DATA.TRAINING_DATA_DIR="/data_zhenghao/datasets/scenarionet/CAT_waymo_hybrid/"
8
+
9
+ """
10
+ import json
11
+ import os
12
+ import pathlib
13
+ import pickle
14
+ import multiprocessing as mp
15
+ from functools import partial
16
+ import tqdm
17
+ import hydra
18
+ import lmdb
19
+ import omegaconf
20
+ import tqdm
21
+
22
+ from scenestreamer.dataset.dataset import SceneStreamerDataset
23
+
24
+ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
25
+
26
+
27
+ class LMDBBulkWriter:
28
+ def __init__(self, base_path, max_size=1e9):
29
+ """
30
+ Initializes the LMDBBulkWriter to save all data in batches, with map_size for each LMDB file.
31
+ Args:
32
+ base_path: Directory path to save LMDB files.
33
+ max_size: Maximum size of each LMDB file in bytes.
34
+ """
35
+ self.base_path = base_path
36
+ # Create the cache directory if it doesn't exist
37
+ os.makedirs(self.base_path, exist_ok=True)
38
+
39
+ self.max_size = int(max_size) # Set the max LMDB file size (e.g., 1 GB)
40
+ self.current_db_index = 0
41
+ self.lookup = {} # Lookup table to track which LMDB file stores which sample
42
+ self.current_db = self._open_new_lmdb(self.current_db_index)
43
+ self.per_shard_size = 0
44
+
45
+ self.sample_buffer = []
46
+
47
+ def _open_new_lmdb(self, db_index):
48
+ """Opens a new LMDB file for saving samples."""
49
+ db_path = f"{self.base_path}/data_{db_index}.lmdb"
50
+ return lmdb.open(db_path, map_size=self.max_size)
51
+
52
+ def _save_a_batch(self):
53
+
54
+ try:
55
+ # Commit the transaction if we have reached the commit interval
56
+ # if (not hasattr(self, 'txn')) or (self.txn is None):
57
+ # self.txn = self.current_db.begin(write=True) # Start a new transaction
58
+ #
59
+ #
60
+ # for key, data in self.sample_buffer:
61
+ # self.txn.put(key.encode('ascii'), pickle.dumps(data))
62
+ # self.lookup[key] = f"data_{self.current_db_index}.lmdb"
63
+ #
64
+ # if hasattr(self, 'txn') and self.txn:
65
+ # self.txn.commit() # Commit the transaction
66
+ print(f"Saving {len(self.sample_buffer)} samples to data_{self.current_db_index}.lmdb")
67
+ with self.current_db.begin(write=True) as txn:
68
+ for key, data in self.sample_buffer:
69
+ txn.put(key.encode('ascii'), pickle.dumps(data))
70
+ self.lookup[key] = f"data_{self.current_db_index}.lmdb"
71
+
72
+ self.sample_buffer.clear()
73
+
74
+ except lmdb.MapFullError:
75
+
76
+ # If current LMDB file is full, create a new one and retry saving
77
+ self.current_db.close()
78
+ self.current_db_index += 1
79
+ print(f"Creating new LMDB file: data_{self.current_db_index}.lmdb (size: {self.per_shard_size})")
80
+ self.current_db = self._open_new_lmdb(self.current_db_index)
81
+ self._save_a_batch()
82
+ self.per_shard_size = 0
83
+
84
+ def save_sample(self, key, data):
85
+ """Saves a sample to the current LMDB file, switching to a new file if necessary."""
86
+ # Batch writes into a single transaction
87
+ if self.per_shard_size % 100 == 0:
88
+ self._save_a_batch()
89
+ self.sample_buffer.append((key, data))
90
+ self.per_shard_size += 1
91
+
92
+ def close(self):
93
+ self._save_a_batch()
94
+ """Closes the LMDB environment and saves the lookup table as a JSON file."""
95
+ self.current_db.close()
96
+ # Save the lookup table to track the LMDB file where each sample is stored
97
+ with open(f"{self.base_path}/lookup.json", "w") as f:
98
+ json.dump(self.lookup, f)
99
+
100
+
101
+ def preprocess_and_queue_worker(worker_id, config, indices, queue):
102
+ """
103
+ This function runs in each worker to preprocess samples and send them to the write queue.
104
+ The writer process will handle writing to LMDB.
105
+ """
106
+ print(f"Worker {worker_id} started.")
107
+ dataset = SceneStreamerDataset(config, "training")
108
+
109
+ print(f"Worker {worker_id} has {len(dataset)} samples.")
110
+
111
+ # Process and queue each sample assigned to this worker
112
+ if worker_id == 0:
113
+ pbar = tqdm.tqdm(indices, desc="Worker %d" % worker_id)
114
+ else:
115
+ pbar = indices
116
+ print(f"Worker {worker_id} has {len(indices)} samples.")
117
+
118
+ for i in pbar:
119
+ sample = dataset[i] # Access the sample using its index
120
+
121
+ # Simulate some preprocessing (replace with actual preprocessing logic)
122
+ file_name, processed_sample = sample["file_name"], sample
123
+
124
+ # Put the preprocessed sample into the queue to be written by the writer process
125
+ print(f"Worker {worker_id} processed {file_name}")
126
+ queue.put((file_name, processed_sample))
127
+
128
+ # Signal that this worker is done
129
+ # queue.put(None) # 'None' signals that the worker is done
130
+
131
+
132
+ def write_process(queue, base_path, max_size):
133
+ """
134
+ The write process receives samples from the queue and writes them to the LMDB environment.
135
+ """
136
+ writer = LMDBBulkWriter(base_path=base_path, max_size=max_size)
137
+ print("Writer process started.")
138
+
139
+ while True:
140
+
141
+ # Blocking if no data is available
142
+ data = queue.get()
143
+
144
+ if data == 100:
145
+ print("Received 100, stopping writer process.")
146
+ # If 'None' is received, this indicates that a worker has finished
147
+ break
148
+
149
+ if data is None:
150
+ print("Received None, stopping writer process.")
151
+ continue
152
+
153
+ file_name, processed_sample = data
154
+ print(f"Saved {file_name} to LMDB")
155
+ writer.save_sample(file_name, processed_sample)
156
+
157
+ # Close the writer once all workers are done
158
+ writer.close()
159
+
160
+
161
+ @hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1024_gpt.yaml")
162
+ def make_lmdb(config):
163
+ omegaconf.OmegaConf.set_struct(config, False)
164
+ omegaconf.OmegaConf.set_struct(config, True)
165
+
166
+ dataset = SceneStreamerDataset(config, "training")
167
+ folder = pathlib.Path(dataset.data_dir)
168
+ folder = folder / "cache"
169
+ folder.mkdir(parents=True, exist_ok=False)
170
+
171
+ # Initialize the LMDBBulkWriter
172
+ print("Saving data to LMDB folder:", folder.absolute())
173
+
174
+ # num_workers = mp.cpu_count()
175
+ num_workers = 2
176
+
177
+ dataset_size = len(dataset)
178
+ indices = list(range(dataset_size))
179
+ chunk_size = dataset_size // num_workers
180
+
181
+ # Split the indices into chunks, one for each worker
182
+ chunked_indices = [indices[i * chunk_size:(i + 1) * chunk_size] for i in range(num_workers)]
183
+ # The final chunk may have more samples if the dataset size is not divisible by the number of workers.
184
+ chunked_indices[0].extend(indices[num_workers * chunk_size:])
185
+
186
+ # Create a multiprocessing queue
187
+ queue = mp.Queue()
188
+
189
+ # Create and start the writer process
190
+ writer_process = mp.Process(target=write_process, args=(queue, folder, 1e10))
191
+
192
+ writer_process.start()
193
+
194
+ # Create a multiprocessing pool for parallel processing (preprocessing)
195
+ pool = mp.Pool(num_workers)
196
+
197
+ results = []
198
+ # Start each worker process, passing its chunk of indices
199
+ for worker_id, worker_indices in enumerate(chunked_indices):
200
+ print(f"Starting worker {worker_id} with {len(worker_indices)} samples.")
201
+ result = pool.apply_async(preprocess_and_queue_worker, args=(worker_id, config, worker_indices, queue))
202
+ results.append(result)
203
+
204
+ # Wait for all worker processes to complete
205
+ # for result in results:
206
+ # result.get() # This will block until the worker completes its task
207
+ # preprocess_and_queue_worker(0, config, chunked_indices[0], queue)
208
+ pool.close()
209
+ print("Waiting for workers to finish...")
210
+ pool.join()
211
+ print("All workers finished.")
212
+
213
+ # Signal the writer process to stop (send 'None' once all workers are done)
214
+ queue.put(100)
215
+ # Wait for the writer process to finish
216
+ writer_process.join()
217
+
218
+
219
+ @hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml")
220
+ def debug(config):
221
+ omegaconf.OmegaConf.set_struct(config, False)
222
+ omegaconf.OmegaConf.set_struct(config, True)
223
+ dataset = SceneStreamerDataset(config, "training")
224
+ folder = pathlib.Path(dataset.data_dir)
225
+ folder = folder / "cache"
226
+ folder.mkdir(parents=True, exist_ok=True)
227
+ for i, sample in enumerate(tqdm.tqdm(dataset, total=len(dataset), desc="Scenarios")):
228
+ file_name = sample["file_name"]
229
+
230
+
231
+ if __name__ == '__main__':
232
+ make_lmdb()
233
+ # debug()
scenestreamer/dataset/preprocess_action_label.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from shapely.geometry import Polygon
3
+
4
+ from scenestreamer.utils import utils
5
+
6
+ INVALID_VALUE = -10000
7
+
8
+
9
+ class TurnAction:
10
+ STOP = 0
11
+ KEEP_STRAIGHT = 1
12
+ TURN_LEFT = 2
13
+ TURN_RIGHT = 3
14
+ U_TURN = 4
15
+
16
+ num_actions = 5
17
+
18
+
19
+ class AccelerationAction:
20
+ STOP = 0
21
+ KEEP_SPEED = 1
22
+ SPEED_UP = 2
23
+ SLOW_DOWN = 3
24
+
25
+ num_actions = 4
26
+
27
+
28
+ class SafetyAction:
29
+ SAFE = 0
30
+ COLLISION = 1
31
+ num_actions = 2
32
+
33
+
34
+ def cal_polygon_contour(x, y, theta, width, length):
35
+
36
+ left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
37
+ left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
38
+ left_front = np.column_stack((left_front_x, left_front_y))
39
+
40
+ right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
41
+ right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
42
+ right_front = np.column_stack((right_front_x, right_front_y))
43
+
44
+ right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
45
+ right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
46
+ right_back = np.column_stack((right_back_x, right_back_y))
47
+
48
+ left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
49
+ left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
50
+ left_back = np.column_stack((left_back_x, left_back_y))
51
+
52
+ polygon_contour = np.concatenate(
53
+ (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1
54
+ )
55
+
56
+ return polygon_contour
57
+
58
+
59
+ def detect_collision(contour_list1, mask1, contour_list2, mask2):
60
+ collision_detected = []
61
+ assert len(contour_list1) == len(contour_list2)
62
+
63
+ for i in range(len(contour_list1)):
64
+ if mask1[i] and mask2[i]:
65
+ poly1 = Polygon(contour_list1[i])
66
+ poly2 = Polygon(contour_list2[i])
67
+
68
+ if poly1.intersects(poly2):
69
+ collision_detected.append(True)
70
+ else:
71
+ collision_detected.append(False)
72
+ else:
73
+ collision_detected.append(False)
74
+
75
+ return collision_detected
76
+
77
+
78
+ def get_direction_action_from_trajectory_batch(traj, mask, dt=0.1, ooi=None):
79
+ U_TURN_DEG = 115
80
+ LEFT_TURN_DEG = 25
81
+ RIGHT_TURN_DEG = -25
82
+ STOP_SPEED = 0.06
83
+
84
+ assert traj.ndim == 3
85
+ traj_diff = traj[1:] - traj[:-1]
86
+ mask_diff = mask[1:] & mask[:-1]
87
+
88
+ displacement = np.linalg.norm(traj_diff, axis=-1)
89
+
90
+ mask_diff_stop = mask_diff & (displacement > 0.1)
91
+
92
+ pred_angles = np.arctan2(traj_diff[..., 1], traj_diff[..., 0])
93
+ pred_angles_diff = utils.wrap_to_pi(pred_angles[1:] - pred_angles[:-1])
94
+
95
+ # It's meaning less to compute heading for a stopped vehicle. So mask them out!
96
+ mask_diff_diff = mask_diff_stop[1:] & mask_diff_stop[:-1]
97
+ # Note that we should not wrap to pi here because the sign is important.
98
+ accumulated_heading_change_rad = (pred_angles_diff * mask_diff_diff).sum(axis=0)
99
+ accumulated_heading_change_deg = np.degrees(accumulated_heading_change_rad)
100
+
101
+ # print("accumulated_heading_change_deg: ", list(zip(ooi, accumulated_heading_change_deg)))
102
+
103
+ speed = displacement / dt
104
+ avg_speed = utils.masked_average_numpy(speed, mask_diff, dim=0)
105
+
106
+ actions = np.zeros(accumulated_heading_change_deg.shape, dtype=int)
107
+ actions.fill(TurnAction.KEEP_STRAIGHT)
108
+ actions[accumulated_heading_change_deg > LEFT_TURN_DEG] = TurnAction.TURN_LEFT
109
+ actions[accumulated_heading_change_deg < RIGHT_TURN_DEG] = TurnAction.TURN_RIGHT
110
+ actions[accumulated_heading_change_deg > U_TURN_DEG] = TurnAction.U_TURN
111
+ actions[accumulated_heading_change_deg < -U_TURN_DEG] = TurnAction.U_TURN
112
+ actions[avg_speed < STOP_SPEED] = TurnAction.STOP
113
+ return actions
114
+
115
+
116
+ def get_acce_action_from_trajectory_batch(batch_trajs, mask, ooi=None, dt=0.1):
117
+
118
+ SPEEDUP_ACCEL = 0.3
119
+ SPEEDDOWN_ACCEL = -0.3
120
+ STOP_SPEED = 0.06
121
+
122
+ traj_diff = batch_trajs[1:] - batch_trajs[:-1] # (T, A, 2)
123
+ mask_diff = mask[1:] & mask[:-1] # (T, A)
124
+
125
+ speed = np.linalg.norm(traj_diff, axis=-1) / dt # (T, A)
126
+
127
+ speed_change = speed[1:] - speed[:-1]
128
+ mask_diff_diff = mask_diff[1:] & mask_diff[:-1]
129
+
130
+ absolute_avg_speed = utils.masked_average_numpy(speed, mask_diff, dim=0)
131
+
132
+ accumulated_speed_change = (speed_change * mask_diff_diff).sum(0)
133
+
134
+ init_speed_ind = mask_diff.argmax(axis=0)
135
+ init_speed = np.take_along_axis(speed, init_speed_ind[None, :], axis=0)[0]
136
+
137
+ speed_change_ratio = accumulated_speed_change / np.maximum(init_speed, STOP_SPEED)
138
+
139
+ # print("speed_change_ratio: ", list(zip(ooi, speed_change_ratio)))
140
+
141
+ actions = np.zeros(speed_change_ratio.shape, dtype=int)
142
+
143
+ actions.fill(AccelerationAction.KEEP_SPEED)
144
+ actions[speed_change_ratio > SPEEDUP_ACCEL] = AccelerationAction.SPEED_UP
145
+ actions[speed_change_ratio < SPEEDDOWN_ACCEL] = AccelerationAction.SLOW_DOWN
146
+ actions[absolute_avg_speed <= STOP_SPEED] = AccelerationAction.STOP # if stop
147
+
148
+ return actions
149
+
150
+
151
+ def get_safety_action_from_sdc_adv(data_dict, adv_id, sdc_id):
152
+
153
+ contours = []
154
+ for agent_id in [adv_id, sdc_id]:
155
+ traj = data_dict["decoder/agent_position"][:91, agent_id, :] # (91, 3)
156
+ length = data_dict["decoder/agent_shape"][:91, agent_id, 0]
157
+ width = data_dict["decoder/agent_shape"][:91, agent_id, 1]
158
+ theta = data_dict['decoder/agent_heading'][:91, agent_id] # (91, ) # in pi
159
+ mask = data_dict['decoder/agent_valid_mask'][:91, agent_id] # (91,)
160
+
161
+ poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length)
162
+ contours.append(poly)
163
+
164
+ sdc_mask = data_dict['decoder/agent_valid_mask'][:, sdc_id] # (91,)
165
+ adv_mask = data_dict['decoder/agent_valid_mask'][:, adv_id]
166
+ adv_contour = contours[0]
167
+ sdc_contour = contours[1]
168
+
169
+ collision_detected = detect_collision(adv_contour, adv_mask, sdc_contour, sdc_mask)
170
+
171
+ # instead of loading a dict which saves all collision scenario, we could simply detect all agents' potential collision
172
+ return collision_detected
173
+
174
+
175
+ def get_safety_action_from_trajectory_batch(data_dict, track_agent_indicies):
176
+
177
+ safety_actions = np.zeros((track_agent_indicies.shape[0], ), dtype=int) # plus sdc
178
+
179
+ contours = []
180
+ for agent1_id in track_agent_indicies:
181
+ traj = data_dict["decoder/agent_position"][:, agent1_id, :] # (91, 3)
182
+ length = data_dict["decoder/agent_shape"][:, agent1_id, 0]
183
+ width = data_dict["decoder/agent_shape"][:, agent1_id, 1]
184
+ theta = data_dict['decoder/agent_heading'][:, agent1_id] # (91, ) # in pi
185
+ mask = data_dict['decoder/agent_valid_mask'][:, agent1_id] # (91,)
186
+ poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length)
187
+ contours.append(poly)
188
+
189
+ for i in range(track_agent_indicies.shape[0] - 1):
190
+ for j in range(i + 1, track_agent_indicies.shape[0]):
191
+ mask_1 = data_dict['decoder/agent_valid_mask'][:, track_agent_indicies[i]] # (91,)
192
+ mask_2 = data_dict['decoder/agent_valid_mask'][:, track_agent_indicies[j]]
193
+ collision_detected = detect_collision(contours[i], mask_1, contours[j], mask_2)
194
+
195
+ if any(collision_detected):
196
+ # print(f"Collision between {i} and {j} happen at step: {np.array(collision_detected).nonzero()}")
197
+ safety_actions[i] = 1 # Label collisions for OOIs now. Later we will build a larger dict.
198
+ safety_actions[j] = 1
199
+
200
+ # instead of loading a dict which saves all collision scenario, we could simply detect all agents' potential collision
201
+ return safety_actions
202
+
203
+
204
+ def prepare_action_label(*, data_dict, dt, mask_probability, config):
205
+ """
206
+ mask_probability: the probability of masking the label. Should be around 0.05 or 0.1. Can't be too high.
207
+ """
208
+ ooi_ind = data_dict["decoder/labeled_agent_id"]
209
+ ooi_pos = utils.extract_data_by_agent_indices(data_dict["decoder/agent_position"], ooi_ind, agent_dim=1)[..., :2]
210
+ ooi_valid = utils.extract_data_by_agent_indices(
211
+ data_dict["decoder/agent_valid_mask"], ooi_ind, agent_dim=1
212
+ ) # (T, A)
213
+
214
+ # TODO: hardcoded here for now and we assume you can access GT trajectory. This won't work with test dataset.
215
+ assert ooi_pos.shape[0] == 91
216
+ assert ooi_valid.shape[0] == 91
217
+
218
+ # get the degree, acceleration, speed
219
+ turn_actions = get_direction_action_from_trajectory_batch(traj=ooi_pos, mask=ooi_valid, dt=dt, ooi=ooi_ind)
220
+ acce_actions = get_acce_action_from_trajectory_batch(ooi_pos, ooi_valid, dt=dt, ooi=ooi_ind)
221
+
222
+ # Rescatter labels to decoder-agent indices
223
+ assert config.TRAINING.PREDICT_ALL_AGENTS
224
+ B = data_dict["decoder/agent_valid_mask"].shape[1]
225
+
226
+ full_turn_actions = np.full((B, ), -1, dtype=int)
227
+ full_acce_actions = np.full((B, ), -1, dtype=int)
228
+
229
+ label_mask = np.random.binomial(1, mask_probability, size=len(ooi_ind))
230
+ label_invalid_mask = label_mask == 1
231
+
232
+ turn_actions[label_invalid_mask] = -1
233
+ acce_actions[label_invalid_mask] = -1
234
+
235
+ full_turn_actions[ooi_ind] = turn_actions
236
+ full_acce_actions[ooi_ind] = acce_actions
237
+
238
+ data_dict["decoder/label_turning"] = full_turn_actions
239
+ data_dict["decoder/label_acceleration"] = full_acce_actions
240
+
241
+ return data_dict
242
+
243
+
244
+ def prepare_safety_label(*, data_dict, dt, mask_probability, config):
245
+ ooi_ind = data_dict["decoder/labeled_agent_id"]
246
+
247
+ ooi_pos = utils.extract_data_by_agent_indices(data_dict["decoder/agent_position"], ooi_ind, agent_dim=1)[..., :2]
248
+ ooi_valid = utils.extract_data_by_agent_indices(
249
+ data_dict["decoder/agent_valid_mask"], ooi_ind, agent_dim=1
250
+ ) # (T, A)
251
+
252
+ # TODO: hardcoded here for now and we assume you can access GT trajectory. This won't work with test dataset.
253
+ assert ooi_pos.shape[0] == 91
254
+ assert ooi_valid.shape[0] == 91
255
+
256
+ safety_actions = get_safety_action_from_trajectory_batch(data_dict, ooi_ind)
257
+
258
+ # Rescatter labels to decoder-agent indices
259
+ assert config.TRAINING.PREDICT_ALL_AGENTS
260
+ num_modeled_agents = data_dict["decoder/agent_valid_mask"].shape[1]
261
+
262
+ full_safety_actions = np.full((num_modeled_agents, ), -1, dtype=int)
263
+
264
+ label_mask = np.random.binomial(1, mask_probability, size=len(ooi_ind))
265
+ label_invalid_mask = label_mask == 1
266
+
267
+ label_invalid_mask[safety_actions == 1] = False # We don't mask collision labels
268
+
269
+ safety_actions[label_invalid_mask] = -1
270
+
271
+ full_safety_actions[ooi_ind] = safety_actions
272
+
273
+ data_dict["decoder/label_safety"] = full_safety_actions
274
+
275
+ return data_dict
276
+
277
+
278
+ if __name__ == '__main__':
279
+ scenario_dir = "/Users/claire_liu/validation_interactive_0/cat_reconstructed/sd_reconstructed_v0_ScenarioMap-21.pkl"
280
+ cat_dir = "/Users/claire_liu/validation_interactive_0/save.pkl"
281
+
282
+ import pickle
283
+
284
+ with open(scenario_dir, 'rb') as f:
285
+ scenario_data = pickle.load(f)
286
+ f.close()
287
+
288
+ with open(cat_dir, 'rb') as ff:
289
+ cat_dict = pickle.load(ff)
290
+ ff.close()
291
+
292
+ batch_labels = get_3d_action_label(scenario_data, cat_dict)
293
+ print(batch_labels)
scenestreamer/dataset/preprocessor.py ADDED
The diff for this file is too large to render. See raw diff
 
scenestreamer/dataset/scenarionet_utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from scenestreamer.utils import wrap_to_pi, rotate
4
+
5
+
6
+ def overwrite_gt_to_pred_field(data_dict):
7
+ import copy
8
+ new_data_dict = copy.deepcopy(data_dict)
9
+ T, N, _ = data_dict["decoder/agent_position"].shape
10
+
11
+ new_data_dict["decoder/reconstructed_position"] = np.zeros((96, N, 2)).astype(np.float32)
12
+ new_data_dict["decoder/reconstructed_valid_mask"] = np.zeros((
13
+ 96,
14
+ N,
15
+ )).astype(bool)
16
+ new_data_dict["decoder/reconstructed_heading"] = np.zeros((
17
+ 96,
18
+ N,
19
+ )).astype(np.float32)
20
+ new_data_dict["decoder/reconstructed_velocity"] = np.zeros((96, N, 2)).astype(np.float32)
21
+
22
+ for id in range(N): # overwrite all agents
23
+ traj = new_data_dict["decoder/agent_position"][:91, id, :2].astype(np.float32)
24
+ traj_mask = new_data_dict["decoder/agent_valid_mask"][:91, id].astype(bool)
25
+ theta = new_data_dict['decoder/agent_heading'][:91, id].astype(np.float32)
26
+ vel = new_data_dict['decoder/agent_velocity'][:91, id].astype(np.float32)
27
+
28
+ new_data_dict["decoder/reconstructed_position"][:91, id, :2] = traj
29
+ # new_data_dict["decoder/reconstructed_position"][:91, id, 2] = 0.0
30
+ new_data_dict["decoder/reconstructed_valid_mask"][:91, id] = traj_mask
31
+ # print(traj_mask)
32
+ new_data_dict["decoder/reconstructed_heading"][:91, id] = theta
33
+ new_data_dict["decoder/reconstructed_velocity"][:91, id] = vel
34
+
35
+ return new_data_dict
36
+
37
+
38
+ def create_new_adv(data_dict):
39
+ ego_id = data_dict["decoder/sdc_index"]
40
+
41
+ ego_traj = data_dict["decoder/agent_position"][:, ego_id]
42
+ ego_heading = data_dict["decoder/agent_heading"][:, ego_id]
43
+ ego_velocity = data_dict["decoder/agent_velocity"][:, ego_id]
44
+ ego_shape = data_dict["decoder/agent_shape"][:, ego_id]
45
+ ego_mask = data_dict["decoder/agent_valid_mask"][:, ego_id]
46
+
47
+ last_valid_step = np.where(ego_mask)[0][-1]
48
+
49
+ # Create a new ADV at the final step.
50
+
51
+ adv_mask = np.zeros_like(ego_mask)
52
+ adv_mask[:last_valid_step + 1] = True
53
+
54
+ adv_traj = np.zeros_like(ego_traj)
55
+ adv_heading = np.zeros_like(ego_heading)
56
+ adv_velocity = np.zeros_like(ego_velocity)
57
+ adv_shape = np.zeros_like(ego_shape)
58
+
59
+ # Copy the final pos/head/vel/shape of ego
60
+ adv_traj[last_valid_step] = ego_traj[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=3)
61
+ adv_heading[last_valid_step] = ego_heading[last_valid_step] + np.random.normal(loc=0.0, scale=0.1, size=1)
62
+ adv_velocity[last_valid_step] = ego_velocity[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=2)
63
+
64
+ for i in range(data_dict["decoder/agent_shape"].shape[0]):
65
+ adv_shape[i] = ego_shape[last_valid_step]
66
+
67
+ # Insert data back:
68
+ data_dict["decoder/agent_position"] = np.concatenate(
69
+ [data_dict["decoder/agent_position"], adv_traj[:, None]], axis=1
70
+ )
71
+ data_dict["decoder/agent_heading"] = np.concatenate(
72
+ [data_dict["decoder/agent_heading"], adv_heading[:, None]], axis=1
73
+ )
74
+ data_dict["decoder/agent_velocity"] = np.concatenate(
75
+ [data_dict["decoder/agent_velocity"], adv_velocity[:, None]], axis=1
76
+ )
77
+ # data_dict["decoder/agent_shape"] = np.concatenate([data_dict["decoder/agent_shape"], adv_shape[:, None]], axis=1)
78
+
79
+ data_dict["decoder/agent_shape"] = np.concatenate([data_dict["decoder/agent_shape"], adv_shape[:, None]], axis=1)
80
+
81
+ data_dict["decoder/agent_valid_mask"] = np.concatenate(
82
+ [data_dict["decoder/agent_valid_mask"], adv_mask[:, None]], axis=1
83
+ )
84
+
85
+ data_dict["decoder/current_agent_shape"] = np.concatenate(
86
+ [data_dict["decoder/current_agent_shape"], data_dict["decoder/current_agent_shape"][ego_id:ego_id + 1]], axis=0
87
+ )
88
+ data_dict["decoder/agent_type"] = np.concatenate(
89
+ [data_dict["decoder/agent_type"], data_dict["decoder/agent_type"][ego_id:ego_id + 1]], axis=0
90
+ )
91
+ data_dict["decoder/agent_id"] = np.concatenate(
92
+ [data_dict["decoder/agent_id"], [len(data_dict["decoder/agent_id"])]], axis=0
93
+ )
94
+
95
+ # Add ADV into OOI:
96
+ data_dict["decoder/object_of_interest_id"] = np.concatenate(
97
+ [data_dict["decoder/object_of_interest_id"], [len(data_dict["decoder/agent_id"]) - 1]], axis=0
98
+ )
99
+
100
+ # Deal with some thing for forward prediction:
101
+ data_dict["decoder/current_agent_valid_mask"] = np.concatenate(
102
+ [data_dict["decoder/current_agent_valid_mask"], [1]], axis=0
103
+ )
104
+
105
+ print("====================================")
106
+ print(
107
+ "The new ADV is created at the final step {}, it's ID is: {}".format(
108
+ last_valid_step,
109
+ len(data_dict["decoder/agent_id"]) - 1
110
+ )
111
+ )
112
+ print("====================================")
113
+
114
+ return data_dict
115
+
116
+
117
+ def overwrite_to_scenario_description(output_dict_mode, original_SD, ooi=None, adv_id=None):
118
+ # overwrite original SD with all predicted ooi trajectories included
119
+ # import pdb; pdb.set_trace()
120
+ if not ooi:
121
+ ooi = output_dict_mode['decoder/agent_id'] # overwrite all agents
122
+ sdc_track_name = original_SD['metadata']['sdc_id']
123
+ adv_track_name = str(output_dict_mode['decoder/track_name'][int(adv_id)].item())
124
+
125
+ for id in ooi:
126
+ agent_track_name = str(output_dict_mode['decoder/track_name'][id].item())
127
+
128
+ # begin to overwrite original scenario_data
129
+ agent_traj = output_dict_mode["decoder/agent_position"][:91, id, ]
130
+ agent_heading = output_dict_mode["decoder/agent_heading"][:91, id]
131
+ agent_vel = output_dict_mode["decoder/agent_velocity"][:91, id]
132
+ agent_traj_mask = output_dict_mode["decoder/agent_valid_mask"][:91, id]
133
+
134
+ # modify adv info
135
+ # agent_z = original_SD['tracks'][agent_track_name]['state']['position'][10, 2] # fill the z-axis
136
+ # agent_traj_z = np.full((91, 1), agent_z)
137
+ # agent_new_traj = np.concatenate([agent_traj, agent_traj_z], axis=1)
138
+ # print("new_traj:", agent_new_traj.shape)
139
+ original_SD['tracks'][agent_track_name]['state']['position'] = agent_traj
140
+ original_SD['tracks'][agent_track_name]['state']['velocity'] = agent_vel
141
+ original_SD['tracks'][agent_track_name]['state']['heading'] = agent_heading
142
+ original_SD['tracks'][agent_track_name]['state']['valid'] = agent_traj_mask
143
+
144
+ length = original_SD['tracks'][agent_track_name]['state']['length'][10]
145
+ width = original_SD['tracks'][agent_track_name]['state']['width'][10]
146
+ height = original_SD['tracks'][agent_track_name]['state']['height'][10]
147
+ original_SD['tracks'][agent_track_name]['state']['length'] = np.full((91, ), length)
148
+ original_SD['tracks'][agent_track_name]['state']['width'] = np.full((91, ), width)
149
+ original_SD['tracks'][agent_track_name]['state']['height'] = np.full((91, ), height)
150
+
151
+ original_SD['metadata']['selected_adv_id'] = adv_track_name
152
+
153
+ return original_SD
154
+
155
+
156
+ def overwrite_to_scenario_description_new_agent(output_dict_mode, original_SD, ooi=None):
157
+ # overwrite original SD with all predicted ooi trajectories included
158
+ ooi = output_dict_mode['decoder/agent_id'] # overwrite all agents
159
+
160
+ adv_track_name = 'new_adv_agent'
161
+ original_SD['tracks'][adv_track_name] = {'state': {}, 'type': 'VEHICLE', 'metadata': {}}
162
+ sdc_track_name = original_SD['metadata']['sdc_id']
163
+
164
+ for id in ooi:
165
+ if id == ooi[-1]:
166
+ agent_track_name = 'new_adv_agent'
167
+ else:
168
+ agent_track_name = str(output_dict_mode['decoder/track_name'][id].item())
169
+
170
+ # begin to overwrite original scenario_data
171
+ agent_traj = output_dict_mode["decoder/agent_position"][:, id, ]
172
+ agent_heading = output_dict_mode["decoder/agent_heading"][:, id]
173
+ agent_vel = output_dict_mode["decoder/agent_velocity"][:, id]
174
+ agent_traj_mask = output_dict_mode["decoder/agent_valid_mask"][:, id]
175
+
176
+ # modify adv info
177
+ # agent_z = original_SD['tracks'][agent_track_name]['state']['position'][10, 2] # fill the z-axis
178
+ # agent_traj_z = np.full((91, 1), agent_z)
179
+ # agent_new_traj = np.concatenate([agent_traj, agent_traj_z], axis=1)
180
+ # print("new_traj:", agent_new_traj.shape)
181
+ original_SD['tracks'][agent_track_name]['state']['position'] = agent_traj
182
+
183
+ original_SD['tracks'][agent_track_name]['state']['velocity'] = agent_vel
184
+ original_SD['tracks'][agent_track_name]['state']['heading'] = agent_heading
185
+ original_SD['tracks'][agent_track_name]['state']['valid'] = agent_traj_mask
186
+
187
+ length = original_SD['tracks'][sdc_track_name]['state']['length'][10]
188
+ width = original_SD['tracks'][sdc_track_name]['state']['width'][10]
189
+ height = original_SD['tracks'][sdc_track_name]['state']['height'][10]
190
+ original_SD['tracks'][agent_track_name]['state']['length'] = np.full((91, ), length)
191
+ original_SD['tracks'][agent_track_name]['state']['width'] = np.full((91, ), width)
192
+ original_SD['tracks'][agent_track_name]['state']['height'] = np.full((91, ), height)
193
+
194
+ original_SD['tracks'][adv_track_name]['metadata']['dataset'] = 'waymo'
195
+ original_SD['tracks'][adv_track_name]['metadata']['object_id'] = 'new_adv_agent'
196
+ original_SD['tracks'][adv_track_name]['metadata']['track_length'] = 91
197
+ original_SD['tracks'][adv_track_name]['metadata']['type'] = 'VEHICLE'
198
+ original_SD['metadata']['new_adv_id'] = 'new_adv_agent'
199
+ original_SD['metadata']['objects_of_interest'].append('new_adv_agent')
200
+ tracks_length = len(list(original_SD['tracks'].keys()))
201
+ original_SD['metadata']['tracks_to_predict']['new_adv_agent'] = {
202
+ 'difficulty': 0,
203
+ 'object_type': 'VEHICLE',
204
+ 'track_id': 'new_adv_agent',
205
+ 'track_index': tracks_length - 1
206
+ }
207
+
208
+ return original_SD
209
+
210
+
211
+ def transform_to_global_coordinate(data_dict):
212
+ map_center = data_dict["metadata/map_center"].reshape(-1, 1, 3) # (1,1,3)
213
+ assert "decoder/agent_position" in data_dict, "Have you set EVALUATION.PREDICT_ALL_AGENTS to False?"
214
+ T, N, _ = data_dict["decoder/agent_position"].shape
215
+ assert data_dict["decoder/agent_position"].ndim == 3
216
+ data_dict["decoder/agent_position"] += map_center
217
+
218
+ return data_dict
219
+
220
+
221
+ def _overwrite_datadict_all_agents(data_dict):
222
+ import copy
223
+ new_data_dict = copy.deepcopy(data_dict)
224
+
225
+ T, N, _ = data_dict["decoder/reconstructed_position"].shape
226
+
227
+ for id in range(N): # overwrite all agents
228
+ traj = data_dict["decoder/reconstructed_position"][:91, id, ]
229
+ traj_mask = data_dict["decoder/reconstructed_valid_mask"][:91, id]
230
+ theta = data_dict['decoder/reconstructed_heading'][:91, id]
231
+ vel = data_dict['decoder/reconstructed_velocity'][:91, id]
232
+
233
+ new_data_dict["decoder/agent_position"][:, id, :2] = traj
234
+ new_data_dict["decoder/agent_position"][:, id, 2] = 0.0
235
+ new_data_dict["decoder/agent_valid_mask"][:, id] = traj_mask
236
+ new_data_dict["decoder/agent_heading"][:, id] = theta
237
+ new_data_dict["decoder/agent_velocity"][:, id] = vel
238
+
239
+ return new_data_dict