Spaces:
Sleeping
Sleeping
pengzhenghao commited on
Commit ·
89f8755
1
Parent(s): b8a7066
Set up self-contained Gradio Space
Browse filesBundle the app code and tiny demo dataset so the Hugging Face Space can boot directly into the SceneStreamer demo with sensible headless defaults.
Made-with: Cursor
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +12 -0
- README.md +28 -4
- app.py +58 -0
- cfgs/motion_default.yaml +211 -0
- cfgs/scenestreamer-base-large.yaml +96 -0
- cfgs/scenestreamer-base-small.yaml +96 -0
- cfgs/scenestreamer-base-xl.yaml +97 -0
- cfgs/scenestreamer-full-large-nors.yaml +99 -0
- cfgs/scenestreamer-full-large.yaml +99 -0
- cfgs/scenestreamer-full-small.yaml +96 -0
- cfgs/scenestreamer-full-xl.yaml +100 -0
- data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl +3 -0
- data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl +3 -0
- data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl +3 -0
- data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl +3 -0
- data/20scenarios/process.ipynb +128 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl +3 -0
- data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl +3 -0
- packages.txt +7 -0
- pyproject.toml +80 -0
- requirements.txt +1 -0
- scenestreamer/__init__.py +0 -0
- scenestreamer/cli.py +293 -0
- scenestreamer/clustering.sh +7 -0
- scenestreamer/dataset/__init__.py +0 -0
- scenestreamer/dataset/constants.py +44 -0
- scenestreamer/dataset/datamodule.py +49 -0
- scenestreamer/dataset/dataset.py +630 -0
- scenestreamer/dataset/make_lmdb.py +233 -0
- scenestreamer/dataset/preprocess_action_label.py +293 -0
- scenestreamer/dataset/preprocessor.py +0 -0
- scenestreamer/dataset/scenarionet_utils.py +239 -0
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
.DS_Store
|
| 6 |
+
artifacts/
|
| 7 |
+
outputs/
|
| 8 |
+
lightning_logs/
|
| 9 |
+
wandb/
|
| 10 |
+
scenestreamer/outputs/
|
| 11 |
+
scenestreamer/lightning_logs/
|
| 12 |
+
scenestreamer/eval/outputs/
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
title: SceneStreamer
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.9.0
|
| 8 |
app_file: app.py
|
|
@@ -10,4 +10,28 @@ pinned: false
|
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: SceneStreamer
|
| 3 |
+
emoji: 🚗
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.9.0
|
| 8 |
app_file: app.py
|
|
|
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SceneStreamer Space
|
| 14 |
+
|
| 15 |
+
This Space hosts the interactive Gradio demo for `SceneStreamer`.
|
| 16 |
+
|
| 17 |
+
What is included here:
|
| 18 |
+
|
| 19 |
+
- the Gradio app entrypoint in `app.py`
|
| 20 |
+
- the SceneStreamer package code needed by the demo
|
| 21 |
+
- a tiny bundled ScenarioNet subset in `data/20scenarios`
|
| 22 |
+
|
| 23 |
+
Default behavior:
|
| 24 |
+
|
| 25 |
+
- the app loads the bundled demo subset automatically
|
| 26 |
+
- the model checkpoint is fetched from the Hugging Face Hub by default
|
| 27 |
+
- `SCENESTREAMER_DEVICE` defaults to `cpu`
|
| 28 |
+
|
| 29 |
+
Optional Space variables:
|
| 30 |
+
|
| 31 |
+
- `SCENESTREAMER_DATASET_DIR`
|
| 32 |
+
- `SCENESTREAMER_HF_REPO`
|
| 33 |
+
- `SCENESTREAMER_HF_FILE`
|
| 34 |
+
- `SCENESTREAMER_CKPT`
|
| 35 |
+
- `SCENESTREAMER_DEVICE`
|
| 36 |
+
|
| 37 |
+
If the app shows the setup screen instead of the demo, the dataset path is missing or the demo subset was not uploaded with the repo.
|
app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
os.environ.setdefault("MPLBACKEND", "Agg")
|
| 7 |
+
os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
|
| 8 |
+
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 9 |
+
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from scenestreamer.gradio_ui.demo_app import DEFAULT_HF_FILE, DEFAULT_HF_REPO, build_demo
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _build_space_demo() -> gr.Blocks:
|
| 17 |
+
dataset_dir = os.environ.get("SCENESTREAMER_DATASET_DIR", "data/20scenarios")
|
| 18 |
+
hf_repo = os.environ.get("SCENESTREAMER_HF_REPO", DEFAULT_HF_REPO)
|
| 19 |
+
hf_file = os.environ.get("SCENESTREAMER_HF_FILE", DEFAULT_HF_FILE)
|
| 20 |
+
ckpt = os.environ.get("SCENESTREAMER_CKPT") or None
|
| 21 |
+
device = os.environ.get("SCENESTREAMER_DEVICE", "cpu")
|
| 22 |
+
|
| 23 |
+
if not Path(dataset_dir).exists():
|
| 24 |
+
with gr.Blocks(title="SceneStreamer Space Setup") as demo:
|
| 25 |
+
gr.Markdown("## SceneStreamer Space Setup Required")
|
| 26 |
+
gr.Markdown(
|
| 27 |
+
"This Space needs a local ScenarioNet dataset directory before the interactive demo can start.\n\n"
|
| 28 |
+
f"Current `SCENESTREAMER_DATASET_DIR`: `{dataset_dir}`"
|
| 29 |
+
)
|
| 30 |
+
gr.Markdown(
|
| 31 |
+
"Set Space variables or attach storage, then restart the Space:\n"
|
| 32 |
+
"- `SCENESTREAMER_DATASET_DIR`\n"
|
| 33 |
+
"- `SCENESTREAMER_HF_REPO` (optional)\n"
|
| 34 |
+
"- `SCENESTREAMER_HF_FILE` (optional)\n"
|
| 35 |
+
"- `SCENESTREAMER_CKPT` (optional local checkpoint)\n"
|
| 36 |
+
"- `SCENESTREAMER_DEVICE` (default `cpu`)"
|
| 37 |
+
)
|
| 38 |
+
gr.Markdown(
|
| 39 |
+
"This repo is expected to include a tiny bundled demo subset under `data/20scenarios`. "
|
| 40 |
+
"If you are seeing this page after pushing the repo, the demo data was likely not uploaded."
|
| 41 |
+
)
|
| 42 |
+
return demo
|
| 43 |
+
|
| 44 |
+
return build_demo(
|
| 45 |
+
dataset_dir=dataset_dir,
|
| 46 |
+
hf_repo=hf_repo,
|
| 47 |
+
hf_file=hf_file,
|
| 48 |
+
ckpt=ckpt,
|
| 49 |
+
device=device,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
demo = _build_space_demo()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
|
| 57 |
+
demo.launch()
|
| 58 |
+
|
cfgs/motion_default.yaml
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
|
| 4 |
+
# Experiment related
|
| 5 |
+
exp_name: 'default'
|
| 6 |
+
seed: 0
|
| 7 |
+
epochs: 50
|
| 8 |
+
batch_size: 10
|
| 9 |
+
val_batch_size: 4
|
| 10 |
+
num_workers: 16
|
| 11 |
+
val_num_workers: 16
|
| 12 |
+
num_sanity_val_steps: 100
|
| 13 |
+
val_check_interval: 1.0
|
| 14 |
+
wandb: False
|
| 15 |
+
log_dir: Null
|
| 16 |
+
limit_train_batches: -1
|
| 17 |
+
limit_val_batches: -1
|
| 18 |
+
prefetch_factor: 2
|
| 19 |
+
ckpt: Null
|
| 20 |
+
eval: False
|
| 21 |
+
pretrain: Null
|
| 22 |
+
deterministic: False
|
| 23 |
+
detect_anomaly: False
|
| 24 |
+
check_val_every_n_epoch: 1
|
| 25 |
+
|
| 26 |
+
USE_RL_FINETUNING: False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Turn both on when training TG to match TrafficGen's metrics.
|
| 30 |
+
LIMIT_MAP_RANGE: False
|
| 31 |
+
FOLLOW_TRAFFICGEN: False
|
| 32 |
+
FORCE_SDC_FOR_TRAFFICGEN: False
|
| 33 |
+
ONLY_LANE_FOR_TRAFFICGEN: False
|
| 34 |
+
|
| 35 |
+
# True then agent info will not be used in encoder,
|
| 36 |
+
# and new tokens for history will be added for decoder.
|
| 37 |
+
GPT_STYLE: false
|
| 38 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: false
|
| 39 |
+
USE_DIFFUSION: false
|
| 40 |
+
USE_ADALN: null
|
| 41 |
+
BACKWARD_PREDICTION: false
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
USE_DESTINATION: false
|
| 45 |
+
|
| 46 |
+
TF_DEST: True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
ADD_CONTOUR_RELATION: false
|
| 50 |
+
|
| 51 |
+
DELTA_TOKENIZER_FILE_NAME: "1030_argsort_less_256_128_128.pkl"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
USE_TRAFFICGEN: false
|
| 55 |
+
TRAIN_TRAFFICGEN: null
|
| 56 |
+
USE_MOTION: true
|
| 57 |
+
|
| 58 |
+
EVAL_MOTION: true
|
| 59 |
+
EVAL_TRAFFICGEN: false
|
| 60 |
+
|
| 61 |
+
DELTA_POS_IS_VELOCITY: false
|
| 62 |
+
|
| 63 |
+
SIMPLE_RELATION: false
|
| 64 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 65 |
+
|
| 66 |
+
RECONSTRUCT_MAP: false
|
| 67 |
+
|
| 68 |
+
UPDATE_RELATION: false
|
| 69 |
+
|
| 70 |
+
REMOVE_REL_NORM: false
|
| 71 |
+
|
| 72 |
+
DATA:
|
| 73 |
+
TRAINING_DATA_DIR: '/data/datasets/scenarionet/waymo/training'
|
| 74 |
+
TEST_DATA_DIR: '/data/datasets/scenarionet/waymo/validation'
|
| 75 |
+
ADV_INFO_PATH: 'data/all_adv.pkl'
|
| 76 |
+
SAMPLE_INTERVAL_TRAINING: 1
|
| 77 |
+
SAMPLE_INTERVAL_TEST: 1
|
| 78 |
+
SD_PASSTHROUGH: false
|
| 79 |
+
ALLOW_CACHE: false
|
| 80 |
+
RETURN_HALFWAY: false # Only used when generating LMDB dataset
|
| 81 |
+
USE_LMDB: false
|
| 82 |
+
USE_CACHE: false
|
| 83 |
+
|
| 84 |
+
PREPROCESSING:
|
| 85 |
+
MAX_VECTORS: 128
|
| 86 |
+
MAX_MAP_FEATURES: 512
|
| 87 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10000 # Useless
|
| 88 |
+
MAX_AGENTS: 128
|
| 89 |
+
MAX_TRAFFIC_LIGHTS: 64
|
| 90 |
+
PADDING_TO_MAX: false
|
| 91 |
+
keep_all_data: false # for debug
|
| 92 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: true # Should be True when WOSAC
|
| 93 |
+
REMOVE_TRAFFIC_LIGHT_STATE: false
|
| 94 |
+
TRUNCATE_TIME: -1
|
| 95 |
+
|
| 96 |
+
TRAINING:
|
| 97 |
+
PREDICT_ALL_AGENTS: true
|
| 98 |
+
|
| 99 |
+
EVALUATION:
|
| 100 |
+
NAME: 'waymo_motion_prediction'
|
| 101 |
+
PREDICT_ALL_AGENTS: false
|
| 102 |
+
DELETE_EVAL_RESULT: true
|
| 103 |
+
NUM_MODES: 6
|
| 104 |
+
MAXIMUM_BATCH_SIZE: 10000
|
| 105 |
+
USE_CACHE: true
|
| 106 |
+
USE_TG_AS_GT: 1111
|
| 107 |
+
TG_REJECT_SAMPLING: True
|
| 108 |
+
TG_SDC_DISTANCE_MASKING: False
|
| 109 |
+
|
| 110 |
+
MODEL:
|
| 111 |
+
NAME: 'motionlm'
|
| 112 |
+
D_MODEL: 256
|
| 113 |
+
NUM_ATTN_LAYERS: 4
|
| 114 |
+
NUM_ATTN_HEAD: 8
|
| 115 |
+
# DROPOUT_OF_ATTN: 0.0
|
| 116 |
+
DROPOUT: 0.0
|
| 117 |
+
NUM_DECODER_LAYERS: 6
|
| 118 |
+
ADD_PE_FOR_TOKEN: true
|
| 119 |
+
RELATIVE_PE: true
|
| 120 |
+
RELATIVE_PE_DECODER: false
|
| 121 |
+
PRE_PROJECTION: false
|
| 122 |
+
KNN: 128
|
| 123 |
+
S2S_DISTANCE: null
|
| 124 |
+
SELF_ATTN_KNN: 128
|
| 125 |
+
CROSS_ATTN_KNN: 128
|
| 126 |
+
RANDOMIZE_AGENT_ID: true
|
| 127 |
+
A2S_KNN: null
|
| 128 |
+
A2S_DISTANCE: null
|
| 129 |
+
A2A_KNN: null
|
| 130 |
+
A2A_DISTANCE: null
|
| 131 |
+
ADD_RELATION_TO_V: false
|
| 132 |
+
IS_V7: False
|
| 133 |
+
PER_CONTOUR_POINT_RELATION: null
|
| 134 |
+
|
| 135 |
+
TOKENIZATION:
|
| 136 |
+
TOKENIZATION_METHOD: delta_delta
|
| 137 |
+
NUM_SKIPPED_STEPS: 5
|
| 138 |
+
NUM_BINS: 13
|
| 139 |
+
X_MAX: 3.5 # <<< Deprecated
|
| 140 |
+
X_MIN: -3.5 # <<< Deprecated
|
| 141 |
+
Y_MAX: 3.5 # <<< Deprecated
|
| 142 |
+
Y_MIN: -3.5 # <<< Deprecated
|
| 143 |
+
ADD_NOISE: false
|
| 144 |
+
NOISE_TOPK: 5
|
| 145 |
+
ALLOW_SKIP_STEP: True
|
| 146 |
+
|
| 147 |
+
MIN_DISPLACEMENT: 0.1
|
| 148 |
+
MIN_DISPLACEMENT_INIT: null
|
| 149 |
+
MIN_SPEED: null
|
| 150 |
+
SMOOTH_FACTOR: null
|
| 151 |
+
MAX_HEADING_DIFF: null
|
| 152 |
+
|
| 153 |
+
USE_CONTOUR_ERROR: True
|
| 154 |
+
|
| 155 |
+
VEH_LIMIT: 3.5
|
| 156 |
+
PED_LIMIT: 3.5
|
| 157 |
+
CYC_LIMIT: 3.5
|
| 158 |
+
|
| 159 |
+
FLIP_WRONG_HEADING: false
|
| 160 |
+
SHOULD_STANDARDIZE: true
|
| 161 |
+
|
| 162 |
+
# MIN_DISPLACEMENT: 0.3
|
| 163 |
+
# MIN_DISPLACEMENT_INIT: 1.0
|
| 164 |
+
# MIN_SPEED: 0.5
|
| 165 |
+
# SMOOTH_FACTOR: null
|
| 166 |
+
# MAX_HEADING_DIFF: 0.3
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
SAMPLING:
|
| 170 |
+
SAMPLING_METHOD: 'topp'
|
| 171 |
+
TEMPERATURE: 1.0
|
| 172 |
+
TOPP: 0.95
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
OPTIMIZATION:
|
| 176 |
+
# NUM_EPOCHS: 50
|
| 177 |
+
OPTIMIZER: AdamW
|
| 178 |
+
LR: 0.0003
|
| 179 |
+
WEIGHT_DECAY: 0.0
|
| 180 |
+
GRAD_NORM_CLIP: 1.0
|
| 181 |
+
SCHEDULER: cosine
|
| 182 |
+
WARMUP_STEPS: 2000
|
| 183 |
+
# TRAINING_STEPS: 300000
|
| 184 |
+
USE_FOCAL_LOSS: false
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
SUBMISSION:
|
| 188 |
+
GENERATE_SUBMISSION: false
|
| 189 |
+
PREFIX: "peng"
|
| 190 |
+
ACCOUNT: "dr.zhenghao.peng@gmail.com"
|
| 191 |
+
METHOD_NAME: "peng"
|
| 192 |
+
num_model_parameters: '10m' # TODO: Need to be changed accordingly!
|
| 193 |
+
SAVE_EVAL_DATA: true
|
| 194 |
+
|
| 195 |
+
TMP_DIR: "tmp" # Relative to repo root
|
| 196 |
+
|
| 197 |
+
ACTION_LABEL:
|
| 198 |
+
USE_ACTION_LABEL: false # Only valid for turning + acceleration
|
| 199 |
+
USE_SAFETY_LABEL: false
|
| 200 |
+
MASK_PROBABILITY_ACTION_LABEL: 0.0 # Might turn it on
|
| 201 |
+
MASK_PROBABILITY_SAFETY_LABEL: 0.0 # Might turn it on
|
| 202 |
+
|
| 203 |
+
LANGUAGE_CONDITION: false
|
| 204 |
+
FINE_TUNE_BERT: false
|
| 205 |
+
|
| 206 |
+
MCTS:
|
| 207 |
+
USE_MCTS: False
|
| 208 |
+
MCTS_DEPTH: -1
|
| 209 |
+
MCTS_WIDTH: -1
|
| 210 |
+
|
| 211 |
+
TOKEN_BUFFER_CACHE_LENGTH: 100
|
cfgs/scenestreamer-base-large.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-base-large'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250506_scenestreamer_v18_notg_large_FIXETYPE_2025-05-06/checkpoints"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: true
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-base-large"
|
| 96 |
+
num_model_parameters: '3.3m'
|
cfgs/scenestreamer-base-small.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-base-small'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer_v17_notg_finetune_FIXTYPE_2025-05-07/checkpoints"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: true
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 64
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-base-small"
|
| 96 |
+
num_model_parameters: '1.1m'
|
cfgs/scenestreamer-base-xl.yaml
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-base-xl'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250512_scenestreamer-base-xl_2025-05-12/checkpoints"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: true
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-base-xl"
|
| 96 |
+
num_model_parameters: '4.2m'
|
| 97 |
+
ACCOUNT: "dr.zhenghao.peng@gmail.com"
|
cfgs/scenestreamer-full-large-nors.yaml
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-full-large-nors'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: false
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-full-large"
|
| 96 |
+
num_model_parameters: '4.6m'
|
| 97 |
+
|
| 98 |
+
EVALUATION:
|
| 99 |
+
TG_REJECT_SAMPLING: False
|
cfgs/scenestreamer-full-large.yaml
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-full-large'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250507_scenestreamer-full-large_2025-05-07/checkpoints/20250507_scenestreamer-full-large_2025-05-07_epoch=1-step=77031.ckpt"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: false
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-full-large"
|
| 96 |
+
num_model_parameters: '4.6m'
|
| 97 |
+
|
| 98 |
+
EVALUATION:
|
| 99 |
+
TG_REJECT_SAMPLING: True
|
cfgs/scenestreamer-full-small.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-full-small'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250505_scenestreamer_v19_withtg_nodest_FIXEDAS_2025-05-05/checkpoints"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: false
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 64 # TODO: Need to increase? was 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 4 # TODO: Need to increase? was 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 2 # TODO: Need to increase? was 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 4 # TODO: Need to increase? was 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-full-small"
|
| 96 |
+
num_model_parameters: '1.5m'
|
cfgs/scenestreamer-full-xl.yaml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- motion_default
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
exp_name: 'scenestreamer-full-xl'
|
| 6 |
+
pretrain: "/bigdata/zhenghao/scenestreamer/lightning_logs/scenestreamer/20250518_scenestreamer-full-xs_2025-05-18/checkpoints/20250518_scenestreamer-full-xs_2025-05-18_epoch=1-step=62494.ckpt"
|
| 7 |
+
|
| 8 |
+
num_workers: 8
|
| 9 |
+
val_num_workers: 8
|
| 10 |
+
num_sanity_val_steps: 10
|
| 11 |
+
|
| 12 |
+
batch_size: 4
|
| 13 |
+
val_batch_size: 4
|
| 14 |
+
limit_val_batches: -1
|
| 15 |
+
|
| 16 |
+
eval_backward_model: False
|
| 17 |
+
|
| 18 |
+
epochs: 30
|
| 19 |
+
wandb: True
|
| 20 |
+
log_dir: /bigdata/zhenghao/scenestreamer
|
| 21 |
+
|
| 22 |
+
SCENESTREAMER_ATTENTION_KNN: 128
|
| 23 |
+
SCENESTREAMER_ATTENTION_MAX_DISTANCE: 50
|
| 24 |
+
SCENESTREAMER_NO_TG: false
|
| 25 |
+
|
| 26 |
+
REMOVE_AGENT_FROM_SCENE_ENCODER: True # <<<
|
| 27 |
+
|
| 28 |
+
BACKWARD_PREDICTION: False # <<<
|
| 29 |
+
ADD_CONTOUR_RELATION: True # <<<
|
| 30 |
+
|
| 31 |
+
DELTA_POS_IS_VELOCITY: True
|
| 32 |
+
SIMPLE_RELATION: True
|
| 33 |
+
|
| 34 |
+
RECONSTRUCT_MAP: False
|
| 35 |
+
UPDATE_RELATION: False
|
| 36 |
+
REMOVE_REL_NORM: False # <<<
|
| 37 |
+
|
| 38 |
+
USE_TRAFFICGEN: True
|
| 39 |
+
USE_MOTION: True
|
| 40 |
+
EVAL_MOTION: True
|
| 41 |
+
EVAL_TRAFFICGEN: False
|
| 42 |
+
|
| 43 |
+
GPT_STYLE: True # <<<
|
| 44 |
+
USE_ADALN: False
|
| 45 |
+
|
| 46 |
+
SAMPLING:
|
| 47 |
+
TOPP: 0.95
|
| 48 |
+
TEMPERATURE: 1.0
|
| 49 |
+
|
| 50 |
+
TOKENIZATION:
|
| 51 |
+
TOKENIZATION_METHOD: "BicycleModelTokenizerFixed0124" # <<<
|
| 52 |
+
USE_CONTOUR_ERROR: True # <<<
|
| 53 |
+
ALLOW_SKIP_STEP: True
|
| 54 |
+
ADD_NOISE: False
|
| 55 |
+
NUM_BINS: 33
|
| 56 |
+
|
| 57 |
+
PREPROCESSING:
|
| 58 |
+
REMOVE_TRAFFIC_LIGHT_STATE: False
|
| 59 |
+
MAX_LENGTH_PER_MAP_FEATURE: 10
|
| 60 |
+
MAX_MAP_FEATURES: 3000
|
| 61 |
+
MAX_VECTORS: 30
|
| 62 |
+
MAX_AGENTS: 128
|
| 63 |
+
DEST_DROPOUT: 0.0
|
| 64 |
+
ADD_SDC_TO_OBJECT_OF_INTEREST: False
|
| 65 |
+
|
| 66 |
+
DATA:
|
| 67 |
+
TRAINING_DATA_DIR: ''
|
| 68 |
+
TEST_DATA_DIR: ''
|
| 69 |
+
|
| 70 |
+
MODEL:
|
| 71 |
+
USE_MOTION_HEAD_PRENORM: True
|
| 72 |
+
ALL_TO_MAP_3D: False
|
| 73 |
+
D_MODEL: 128
|
| 74 |
+
NAME: 'scenestreamer'
|
| 75 |
+
NUM_ATTN_HEAD: 8
|
| 76 |
+
# Encoder:
|
| 77 |
+
NUM_ATTN_LAYERS: 3
|
| 78 |
+
RELATIVE_PE: true
|
| 79 |
+
# Decoder:
|
| 80 |
+
NUM_DECODER_LAYERS: 6
|
| 81 |
+
RELATIVE_PE_DECODER: True
|
| 82 |
+
SIMPLE_RELATION_FACTOR: 1
|
| 83 |
+
# New:
|
| 84 |
+
KNN: -100
|
| 85 |
+
S2S_DISTANCE: -100
|
| 86 |
+
A2S_KNN: -100
|
| 87 |
+
A2S_DISTANCE: -100
|
| 88 |
+
A2A_KNN: -100
|
| 89 |
+
A2A_DISTANCE: -100
|
| 90 |
+
ADD_RELATION_TO_V: False
|
| 91 |
+
PER_CONTOUR_POINT_RELATION: False
|
| 92 |
+
IS_V7: True
|
| 93 |
+
|
| 94 |
+
SUBMISSION:
|
| 95 |
+
METHOD_NAME: "scenestreamer-full-xl"
|
| 96 |
+
num_model_parameters: '5.5m'
|
| 97 |
+
ACCOUNT: "dr.zhenghao.peng@gmail.com"
|
| 98 |
+
|
| 99 |
+
EVALUATION:
|
| 100 |
+
TG_REJECT_SAMPLING: True
|
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2f5a926ba159d4e9acec464c9d091c093d138a21796ce5c264fea7f4398a777
|
| 3 |
+
size 3007314
|
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b8a001efb1d464a0f9a3c06f26cf921f2308e81b26a65741491875201ede70b1
|
| 3 |
+
size 6364095
|
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:742aa8c350793d83e949396dcb055a17bcbdd4b2b728b41d5f1b5c6c5a897ce1
|
| 3 |
+
size 4382994
|
data/20scenarios/cache/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:716fa8212d8d4dbed60703cf5a9a952129fab5ed9342da748a7a48f00478e6d9
|
| 3 |
+
size 11279523
|
data/20scenarios/process.ipynb
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "f003e6e4",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import os, pickle"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 4,
|
| 16 |
+
"id": "fdb6f299",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [
|
| 19 |
+
{
|
| 20 |
+
"name": "stdout",
|
| 21 |
+
"output_type": "stream",
|
| 22 |
+
"text": [
|
| 23 |
+
"\u001b[0m\u001b[01;32mdataset_summary.pkl\u001b[0m*\r\n",
|
| 24 |
+
"\u001b[01;32mprocess.ipynb\u001b[0m*\r\n",
|
| 25 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl\u001b[0m*\r\n",
|
| 26 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl\u001b[0m*\r\n",
|
| 27 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl\u001b[0m*\r\n",
|
| 28 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl\u001b[0m*\r\n",
|
| 29 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl\u001b[0m*\r\n",
|
| 30 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl\u001b[0m*\r\n",
|
| 31 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl\u001b[0m*\r\n",
|
| 32 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl\u001b[0m*\r\n",
|
| 33 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl\u001b[0m*\r\n",
|
| 34 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl\u001b[0m*\r\n",
|
| 35 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl\u001b[0m*\r\n",
|
| 36 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl\u001b[0m*\r\n",
|
| 37 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl\u001b[0m*\r\n",
|
| 38 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_18840a098288507f.pkl\u001b[0m*\r\n",
|
| 39 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl\u001b[0m*\r\n",
|
| 40 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl\u001b[0m*\r\n",
|
| 41 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl\u001b[0m*\r\n",
|
| 42 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl\u001b[0m*\r\n",
|
| 43 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl\u001b[0m*\r\n",
|
| 44 |
+
"\u001b[01;32msd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl\u001b[0m*\r\n"
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"source": [
|
| 49 |
+
"ls"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": 3,
|
| 55 |
+
"id": "f08caf3b",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [
|
| 58 |
+
{
|
| 59 |
+
"data": {
|
| 60 |
+
"text/plain": [
|
| 61 |
+
"['sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl',\n",
|
| 62 |
+
" 'sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl',\n",
|
| 63 |
+
" 'sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl',\n",
|
| 64 |
+
" 'sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl',\n",
|
| 65 |
+
" 'sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl',\n",
|
| 66 |
+
" 'sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl',\n",
|
| 67 |
+
" 'sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl',\n",
|
| 68 |
+
" 'sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl',\n",
|
| 69 |
+
" 'sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl',\n",
|
| 70 |
+
" 'sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl',\n",
|
| 71 |
+
" 'sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl',\n",
|
| 72 |
+
" 'sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl',\n",
|
| 73 |
+
" 'sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl',\n",
|
| 74 |
+
" 'sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl',\n",
|
| 75 |
+
" 'sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl',\n",
|
| 76 |
+
" 'sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl',\n",
|
| 77 |
+
" 'sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl',\n",
|
| 78 |
+
" 'sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl',\n",
|
| 79 |
+
" 'sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl',\n",
|
| 80 |
+
" 'sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl']"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
"execution_count": 3,
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"output_type": "execute_result"
|
| 86 |
+
}
|
| 87 |
+
],
|
| 88 |
+
"source": [
|
| 89 |
+
"[p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": 6,
|
| 95 |
+
"id": "d1d50b21",
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"d = {}\n",
|
| 100 |
+
"for p in [p for p in os.listdir(\".\") if p.endswith(\".pkl\") and p.startswith(\"sd\")]:\n",
|
| 101 |
+
" d[p] = {}\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"pickle.dump(d, open(\"dataset_summary.pkl\", \"wb\"))"
|
| 104 |
+
]
|
| 105 |
+
}
|
| 106 |
+
],
|
| 107 |
+
"metadata": {
|
| 108 |
+
"kernelspec": {
|
| 109 |
+
"display_name": "Python 3 (ipykernel)",
|
| 110 |
+
"language": "python",
|
| 111 |
+
"name": "python3"
|
| 112 |
+
},
|
| 113 |
+
"language_info": {
|
| 114 |
+
"codemirror_mode": {
|
| 115 |
+
"name": "ipython",
|
| 116 |
+
"version": 3
|
| 117 |
+
},
|
| 118 |
+
"file_extension": ".py",
|
| 119 |
+
"mimetype": "text/x-python",
|
| 120 |
+
"name": "python",
|
| 121 |
+
"nbconvert_exporter": "python",
|
| 122 |
+
"pygments_lexer": "ipython3",
|
| 123 |
+
"version": "3.9.16"
|
| 124 |
+
}
|
| 125 |
+
},
|
| 126 |
+
"nbformat": 4,
|
| 127 |
+
"nbformat_minor": 5
|
| 128 |
+
}
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_101d4e5775093d0c.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de3ceb564e944fda1d5b8e2f72428ada2b790d5d57195e6d9bca2cdf761a37f0
|
| 3 |
+
size 338026
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_10c3969f1eb158d.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72cae539b2993d4dd668e5029460e2cdbf621dc1734d69ce30b357a156bc8375
|
| 3 |
+
size 535617
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_1109b0038ed8f25a.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:383b8cdeaefc13c15751f3ff29d28426543c24b6de674e0c5434e39f5d7a8d1f
|
| 3 |
+
size 1060327
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_116d257e98878d94.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77209cb776b340b849ff123c7c9019d17495df41bf69fe2cfc765d5c8235fc66
|
| 3 |
+
size 680549
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_11cdace2c1445900.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf6fc9426a5b209d4f0862ff8046844f40fbf822fcf54a24947f23aa0f140161
|
| 3 |
+
size 426242
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_12a40f114f1ec5fc.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ddfd0775eedec97b30316beb8c537e3d768928abbd255030403c5bdff59a7f85
|
| 3 |
+
size 837708
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_149f682e19454efa.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b2f40bfe73b50954987f15a462a4a86659615e416b577b1c1096c8c69ef0d05
|
| 3 |
+
size 667369
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_16361c8c522cf0e.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30f1f5e1d3a5863a88b919a79c8c4a75bff2774df99412a092706d5cebc7ffbb
|
| 3 |
+
size 319336
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_1663adf01133d82.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bacb875cdcd59d4c5ada7c1a2b546a31967de6f15f288d81136d2f3cc12f2413
|
| 3 |
+
size 560466
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_16b7c9a8ae6e89f1.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5194b6697b1fa3d00113c8ba2f39d50ffb8066b004c8f2f5d15953a550027e7c
|
| 3 |
+
size 456308
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_16d4f837fb57ff1.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfce54f80da166a8cbaf5fe167770975dd4dfcc9bba3c29ad228ad112d995690
|
| 3 |
+
size 396202
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_17100879b93ceb61.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a15fb32533b7472c38ec32c676a9425490de2c95d8dd39b9aa2a6b98f1050512
|
| 3 |
+
size 253625
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_17fa718cd1251a26.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d126e0c7c1e4afa561739437ad7c7810830c766a45b725fb8cb8a33015029e61
|
| 3 |
+
size 490897
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_18840a098288507f.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffbd89555db615f2d20516fedf2dd3689b4487221a82c9f48251e29c30027f31
|
| 3 |
+
size 627272
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_18f5ce249ee2e949.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e93b1b9cbfd8c541c5c52523aee09f5abd3410d217bf6e402ec217935f5df08
|
| 3 |
+
size 435189
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_1a7143a44e480ca6.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11ac7db4a7c917c8045c957ee93d2f761de089152fdb9b03554a889d4f1deca3
|
| 3 |
+
size 376448
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_1a8cc570d620bd31.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5cb753a3c0d7255b5d4333a95c5555b4b159dcc600f9da286787ab37893c70d8
|
| 3 |
+
size 1558365
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_2a1e44d405a6833f.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ac83f07304e39a8c5e468399b12e3ed034d6855b2921ab094ee425e492bea6d
|
| 3 |
+
size 1319602
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_8a346109094cd5aa.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c81210cec08ec7bc5813e59bf4a2f8287800c2d626929441a0e7398f7283c509
|
| 3 |
+
size 1316507
|
data/20scenarios/sd_training.tfrecord-00000-of-01000_c403d5992cab9e0.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26feb91313bee0c89ff1a9660fc352ea687daf04a7ca642f600d2ece17a6e301
|
| 3 |
+
size 1079778
|
packages.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
libgl1
|
| 3 |
+
libglib2.0-0
|
| 4 |
+
libsm6
|
| 5 |
+
libxext6
|
| 6 |
+
libxrender1
|
| 7 |
+
libsdl2-2.0-0
|
pyproject.toml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "scenestreamer"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "SceneStreamer: Continuous Scenario Generation as Next Token Group Prediction"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
license = {text = "MIT"}
|
| 7 |
+
authors = [
|
| 8 |
+
{name = "Zhenghao Peng", email = "pzh@berkeley.edu"}
|
| 9 |
+
]
|
| 10 |
+
requires-python = ">=3.10,<3.12"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Development Status :: 4 - Beta",
|
| 13 |
+
"License :: OSI Approved :: MIT License",
|
| 14 |
+
"Programming Language :: Python :: 3.10",
|
| 15 |
+
"Programming Language :: Python :: 3.11",
|
| 16 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
dependencies = [
|
| 20 |
+
"torch>=2.0.0",
|
| 21 |
+
"torchvision",
|
| 22 |
+
"lightning>=2.0.0",
|
| 23 |
+
"hydra-core",
|
| 24 |
+
"omegaconf",
|
| 25 |
+
"numpy",
|
| 26 |
+
"tqdm",
|
| 27 |
+
"matplotlib",
|
| 28 |
+
"seaborn",
|
| 29 |
+
"Pillow",
|
| 30 |
+
"easydict",
|
| 31 |
+
"wandb",
|
| 32 |
+
"torch_geometric",
|
| 33 |
+
"transformers",
|
| 34 |
+
"tokenizers",
|
| 35 |
+
"huggingface_hub",
|
| 36 |
+
"tensorboardX",
|
| 37 |
+
"pyyaml",
|
| 38 |
+
"scikit-image",
|
| 39 |
+
"chardet",
|
| 40 |
+
"charset-normalizer",
|
| 41 |
+
"tabulate",
|
| 42 |
+
"metadrive-simulator",
|
| 43 |
+
"gradio>=6.9.0",
|
| 44 |
+
"scenarionet @ git+https://github.com/metadriverse/scenarionet.git",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
[project.scripts]
|
| 48 |
+
scenestreamer = "scenestreamer.cli:main"
|
| 49 |
+
|
| 50 |
+
[project.optional-dependencies]
|
| 51 |
+
dev = [
|
| 52 |
+
"ruff",
|
| 53 |
+
"pytest",
|
| 54 |
+
]
|
| 55 |
+
rl = [
|
| 56 |
+
"stable-baselines3>=2.0.0",
|
| 57 |
+
"gymnasium>=0.29.0",
|
| 58 |
+
"ipython",
|
| 59 |
+
]
|
| 60 |
+
# Note: waymo-open-dataset requires Python 3.10 and specific numpy versions.
|
| 61 |
+
# Install separately: pip install waymo-open-dataset-tf-2-12-0==1.6.4
|
| 62 |
+
|
| 63 |
+
[project.urls]
|
| 64 |
+
Homepage = "https://vail-ucla.github.io/scenestreamer/"
|
| 65 |
+
Repository = "https://github.com/pengzhenghao/scenestreamer"
|
| 66 |
+
|
| 67 |
+
[build-system]
|
| 68 |
+
requires = ["setuptools>=61.0"]
|
| 69 |
+
build-backend = "setuptools.build_meta"
|
| 70 |
+
|
| 71 |
+
[tool.setuptools.packages.find]
|
| 72 |
+
include = ["scenestreamer*"]
|
| 73 |
+
|
| 74 |
+
[tool.ruff]
|
| 75 |
+
line-length = 120
|
| 76 |
+
target-version = "py310"
|
| 77 |
+
|
| 78 |
+
[tool.ruff.lint]
|
| 79 |
+
select = ["E", "F", "W"]
|
| 80 |
+
ignore = ["E501"]
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.
|
scenestreamer/__init__.py
ADDED
|
File without changes
|
scenestreamer/cli.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import pathlib
|
| 7 |
+
import runpy
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _to_plain(obj: Any) -> Any:
|
| 17 |
+
if hasattr(obj, "items"):
|
| 18 |
+
return {k: _to_plain(v) for k, v in obj.items()}
|
| 19 |
+
if isinstance(obj, (list, tuple)):
|
| 20 |
+
return [_to_plain(v) for v in obj]
|
| 21 |
+
return obj
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _to_easydict(obj: Any):
|
| 25 |
+
from easydict import EasyDict
|
| 26 |
+
|
| 27 |
+
if isinstance(obj, dict):
|
| 28 |
+
return EasyDict({k: _to_easydict(v) for k, v in obj.items()})
|
| 29 |
+
if isinstance(obj, list):
|
| 30 |
+
return [_to_easydict(v) for v in obj]
|
| 31 |
+
return obj
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_yaml_config(path: str | os.PathLike[str]):
|
| 35 |
+
with open(path, "r") as f:
|
| 36 |
+
data = yaml.safe_load(f)
|
| 37 |
+
return _to_easydict(data)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def apply_overrides(cfg, overrides: list[str]) -> None:
|
| 41 |
+
"""
|
| 42 |
+
Apply overrides of form KEY=VALUE where KEY is dot-delimited.
|
| 43 |
+
VALUE is parsed using yaml.safe_load (so numbers/bools/lists work).
|
| 44 |
+
"""
|
| 45 |
+
for item in overrides:
|
| 46 |
+
if "=" not in item:
|
| 47 |
+
raise ValueError(f"Invalid override (expected KEY=VALUE): {item}")
|
| 48 |
+
key, raw_val = item.split("=", 1)
|
| 49 |
+
value = yaml.safe_load(raw_val)
|
| 50 |
+
|
| 51 |
+
cur = cfg
|
| 52 |
+
parts = key.split(".")
|
| 53 |
+
for p in parts[:-1]:
|
| 54 |
+
if not hasattr(cur, p):
|
| 55 |
+
setattr(cur, p, _to_easydict({}))
|
| 56 |
+
cur = getattr(cur, p)
|
| 57 |
+
setattr(cur, parts[-1], value)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass(frozen=True)
|
| 61 |
+
class RunPaths:
|
| 62 |
+
run_dir: pathlib.Path
|
| 63 |
+
config_path: pathlib.Path
|
| 64 |
+
metrics_path: pathlib.Path
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_run_dir(base_dir: str | os.PathLike[str], run_id: str | None) -> RunPaths:
|
| 68 |
+
base = pathlib.Path(base_dir)
|
| 69 |
+
base.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
if run_id is None:
|
| 71 |
+
run_id = time.strftime("%Y%m%d-%H%M%S")
|
| 72 |
+
run_dir = base / run_id
|
| 73 |
+
run_dir.mkdir(parents=True, exist_ok=False)
|
| 74 |
+
return RunPaths(
|
| 75 |
+
run_dir=run_dir,
|
| 76 |
+
config_path=run_dir / "config.yaml",
|
| 77 |
+
metrics_path=run_dir / "metrics.json",
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def cmd_preprocess(args: argparse.Namespace) -> None:
|
| 82 |
+
# Prefer failing fast on missing ScenarioNet without importing heavy deps.
|
| 83 |
+
try:
|
| 84 |
+
import scenarionet # noqa: F401
|
| 85 |
+
except ModuleNotFoundError as e:
|
| 86 |
+
raise e
|
| 87 |
+
|
| 88 |
+
from scenestreamer.dataset.dataset import SceneStreamerDataset
|
| 89 |
+
|
| 90 |
+
cfg = load_yaml_config(args.config)
|
| 91 |
+
apply_overrides(cfg, args.set or [])
|
| 92 |
+
|
| 93 |
+
# Paths: prefer CLI args, but allow config overrides.
|
| 94 |
+
if args.train_dir:
|
| 95 |
+
cfg.DATA.TRAINING_DATA_DIR = args.train_dir
|
| 96 |
+
if args.test_dir:
|
| 97 |
+
cfg.DATA.TEST_DATA_DIR = args.test_dir
|
| 98 |
+
|
| 99 |
+
cfg.DATA.USE_CACHE = True
|
| 100 |
+
|
| 101 |
+
run = make_run_dir(args.artifacts_dir, args.run_id)
|
| 102 |
+
with open(run.config_path, "w") as f:
|
| 103 |
+
yaml.safe_dump(_to_plain(cfg), f, sort_keys=False)
|
| 104 |
+
|
| 105 |
+
mode = args.split
|
| 106 |
+
ds = SceneStreamerDataset(cfg, mode)
|
| 107 |
+
|
| 108 |
+
# Iterate to materialize cache files.
|
| 109 |
+
for i in range(len(ds)):
|
| 110 |
+
_ = ds[i]
|
| 111 |
+
if args.limit is not None and (i + 1) >= args.limit:
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
metrics = {
|
| 115 |
+
"status": "ok",
|
| 116 |
+
"mode": mode,
|
| 117 |
+
"train_dir": getattr(cfg.DATA, "TRAINING_DATA_DIR", None),
|
| 118 |
+
"test_dir": getattr(cfg.DATA, "TEST_DATA_DIR", None),
|
| 119 |
+
"limit": args.limit,
|
| 120 |
+
}
|
| 121 |
+
with open(run.metrics_path, "w") as f:
|
| 122 |
+
json.dump(metrics, f, indent=2)
|
| 123 |
+
|
| 124 |
+
print(str(run.run_dir))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _load_model_from_args(args: argparse.Namespace):
|
| 128 |
+
import torch
|
| 129 |
+
|
| 130 |
+
from scenestreamer.utils import utils
|
| 131 |
+
|
| 132 |
+
device = torch.device(args.device)
|
| 133 |
+
if args.hf_repo:
|
| 134 |
+
return utils.get_model(huggingface_repo=args.hf_repo, huggingface_file=args.hf_file, device=device)
|
| 135 |
+
if args.ckpt:
|
| 136 |
+
return utils.get_model(checkpoint_path=args.ckpt, device=device)
|
| 137 |
+
raise ValueError("Must provide either --hf-repo/--hf-file or --ckpt")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def cmd_table1(args: argparse.Namespace) -> None:
|
| 141 |
+
from scenestreamer.paper.table1_mmd import run_table1_mmd
|
| 142 |
+
|
| 143 |
+
pl_model = _load_model_from_args(args)
|
| 144 |
+
run_dir = run_table1_mmd(
|
| 145 |
+
pl_model=pl_model,
|
| 146 |
+
dataset_dir=args.dataset_dir,
|
| 147 |
+
split=args.split,
|
| 148 |
+
limit=args.limit,
|
| 149 |
+
artifacts_dir=args.artifacts_dir,
|
| 150 |
+
run_id=args.run_id,
|
| 151 |
+
seed=args.seed,
|
| 152 |
+
)
|
| 153 |
+
print(str(run_dir))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def cmd_table2(args: argparse.Namespace) -> None:
|
| 157 |
+
from scenestreamer.paper.table2_motion import run_table2_motion
|
| 158 |
+
|
| 159 |
+
pl_model = _load_model_from_args(args)
|
| 160 |
+
run_dir = run_table2_motion(
|
| 161 |
+
pl_model=pl_model,
|
| 162 |
+
dataset_dir=args.dataset_dir,
|
| 163 |
+
split=args.split,
|
| 164 |
+
mode=args.mode,
|
| 165 |
+
num_modes=args.num_modes,
|
| 166 |
+
limit=args.limit,
|
| 167 |
+
artifacts_dir=args.artifacts_dir,
|
| 168 |
+
run_id=args.run_id,
|
| 169 |
+
seed=args.seed,
|
| 170 |
+
)
|
| 171 |
+
print(str(run_dir))
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def cmd_densify_demo(args: argparse.Namespace) -> None:
|
| 175 |
+
from scenestreamer.paper.densify_demo import run_densify_demo
|
| 176 |
+
|
| 177 |
+
pl_model = _load_model_from_args(args)
|
| 178 |
+
run_dir = run_densify_demo(
|
| 179 |
+
pl_model=pl_model,
|
| 180 |
+
dataset_dir=args.dataset_dir,
|
| 181 |
+
split=args.split,
|
| 182 |
+
scenario_index=args.scenario_index,
|
| 183 |
+
max_agents=args.max_agents,
|
| 184 |
+
force_no_end=args.force_no_end,
|
| 185 |
+
artifacts_dir=args.artifacts_dir,
|
| 186 |
+
run_id=args.run_id,
|
| 187 |
+
seed=args.seed,
|
| 188 |
+
)
|
| 189 |
+
print(str(run_dir))
|
| 190 |
+
|
| 191 |
+
def _run_module_as_main(module: str, argv: list[str]) -> None:
|
| 192 |
+
old_argv = sys.argv[:]
|
| 193 |
+
try:
|
| 194 |
+
sys.argv = [module] + argv
|
| 195 |
+
runpy.run_module(module, run_name="__main__")
|
| 196 |
+
finally:
|
| 197 |
+
sys.argv = old_argv
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def cmd_table3_train(args: argparse.Namespace) -> None:
|
| 201 |
+
_run_module_as_main("scenestreamer.rl_train.train.train_td3", args.args)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def cmd_table3_eval(args: argparse.Namespace) -> None:
|
| 205 |
+
_run_module_as_main("scenestreamer.rl_train.train.eval_policy", args.args)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 209 |
+
parser = argparse.ArgumentParser(prog="scenestreamer", description="SceneStreamer paper reproduction CLI")
|
| 210 |
+
sub = parser.add_subparsers(dest="cmd", required=True)
|
| 211 |
+
|
| 212 |
+
def add_run_args(p: argparse.ArgumentParser) -> None:
|
| 213 |
+
p.add_argument("--artifacts-dir", default="artifacts", help="Directory to write run artifacts")
|
| 214 |
+
p.add_argument("--run-id", default=None, help="Run ID (default: timestamp)")
|
| 215 |
+
p.add_argument("--seed", type=int, default=0)
|
| 216 |
+
|
| 217 |
+
def add_model_args(p: argparse.ArgumentParser) -> None:
|
| 218 |
+
p.add_argument("--device", default="cuda", help="torch device string, e.g. cuda or cpu")
|
| 219 |
+
p.add_argument("--ckpt", default=None, help="Path to a .ckpt checkpoint")
|
| 220 |
+
p.add_argument("--hf-repo", default=None, help="HuggingFace repo id, e.g. user/repo")
|
| 221 |
+
p.add_argument("--hf-file", default=None, help="HuggingFace filename, e.g. model.ckpt")
|
| 222 |
+
|
| 223 |
+
# preprocess
|
| 224 |
+
p = sub.add_parser("preprocess", help="Preprocess ScenarioNet SD dataset and build cache")
|
| 225 |
+
add_run_args(p)
|
| 226 |
+
p.add_argument("--config", default="cfgs/motion_default.yaml")
|
| 227 |
+
p.add_argument("--set", action="append", default=[], help="Override config KEY=VALUE (repeatable)")
|
| 228 |
+
p.add_argument("--train-dir", default=None)
|
| 229 |
+
p.add_argument("--test-dir", default=None)
|
| 230 |
+
p.add_argument("--split", choices=["training", "test"], default="training")
|
| 231 |
+
p.add_argument("--limit", type=int, default=None)
|
| 232 |
+
p.set_defaults(func=cmd_preprocess)
|
| 233 |
+
|
| 234 |
+
# table1
|
| 235 |
+
p = sub.add_parser("table1", help="Table 1: initial state MMD (strict + relaxed)")
|
| 236 |
+
add_run_args(p)
|
| 237 |
+
add_model_args(p)
|
| 238 |
+
p.add_argument("--dataset-dir", required=True)
|
| 239 |
+
p.add_argument("--split", choices=["training", "test"], default="test")
|
| 240 |
+
p.add_argument("--limit", type=int, default=None)
|
| 241 |
+
p.set_defaults(func=cmd_table1)
|
| 242 |
+
|
| 243 |
+
# table2
|
| 244 |
+
p = sub.add_parser("table2", help="Table 2: motion prediction (ADE/FDE + ADD/FDD)")
|
| 245 |
+
add_run_args(p)
|
| 246 |
+
add_model_args(p)
|
| 247 |
+
p.add_argument("--dataset-dir", required=True)
|
| 248 |
+
p.add_argument("--split", choices=["training", "test"], default="test")
|
| 249 |
+
p.add_argument("--mode", choices=["motion", "full"], default="motion")
|
| 250 |
+
p.add_argument("--num-modes", type=int, default=6)
|
| 251 |
+
p.add_argument("--limit", type=int, default=None)
|
| 252 |
+
p.set_defaults(func=cmd_table2)
|
| 253 |
+
|
| 254 |
+
# demo
|
| 255 |
+
p = sub.add_parser("densify-demo", help="Qualitative densification demo (generate to max agents)")
|
| 256 |
+
add_run_args(p)
|
| 257 |
+
add_model_args(p)
|
| 258 |
+
p.add_argument("--dataset-dir", required=True)
|
| 259 |
+
p.add_argument("--split", choices=["training", "test"], default="test")
|
| 260 |
+
p.add_argument("--scenario-index", type=int, default=0)
|
| 261 |
+
p.add_argument("--max-agents", type=int, default=128)
|
| 262 |
+
p.add_argument("--force-no-end", action="store_true", help="Disable end token so it keeps generating agents")
|
| 263 |
+
p.set_defaults(func=cmd_densify_demo)
|
| 264 |
+
|
| 265 |
+
p = sub.add_parser("table3-train", help="Table 3: RL training (pass-through to train_td3.py)")
|
| 266 |
+
p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to train_td3.py")
|
| 267 |
+
p.set_defaults(func=cmd_table3_train)
|
| 268 |
+
|
| 269 |
+
p = sub.add_parser("table3-eval", help="Table 3: RL evaluation (pass-through to eval_policy.py)")
|
| 270 |
+
p.add_argument("args", nargs=argparse.REMAINDER, help="Arguments forwarded to eval_policy.py")
|
| 271 |
+
p.set_defaults(func=cmd_table3_eval)
|
| 272 |
+
|
| 273 |
+
return parser
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def main(argv: list[str] | None = None) -> None:
|
| 277 |
+
parser = build_parser()
|
| 278 |
+
args = parser.parse_args(argv)
|
| 279 |
+
try:
|
| 280 |
+
args.func(args)
|
| 281 |
+
except ModuleNotFoundError as e:
|
| 282 |
+
# Most common in a fresh environment: scenarionet / waymo-open-dataset missing.
|
| 283 |
+
msg = str(e)
|
| 284 |
+
if "scenarionet" in msg:
|
| 285 |
+
raise SystemExit(
|
| 286 |
+
"Missing dependency 'scenarionet'. Install it via:\n"
|
| 287 |
+
" pip install git+https://github.com/metadriverse/scenarionet.git\n"
|
| 288 |
+
) from e
|
| 289 |
+
raise
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
main()
|
scenestreamer/clustering.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup python clustering.py --data 3 > clustering_obj3_all_nomin.log 2>&1 &
|
| 2 |
+
nohup python clustering.py --data 2 > clustering_obj2_all_nomin.log 2>&1 &
|
| 3 |
+
nohup python clustering.py --data 1 > clustering_obj1_all_nomin.log 2>&1 &
|
| 4 |
+
|
| 5 |
+
#nohup python clustering.py --data 3 --min_scale 0.5 > clustering_obj3_all.log 2>&1 &
|
| 6 |
+
#nohup python clustering.py --data 2 --min_scale 0.5 > clustering_obj2_all.log 2>&1 &
|
| 7 |
+
#nohup python clustering.py --data 1 --min_scale 0.5 > clustering_obj1_all.log 2>&1 &
|
scenestreamer/dataset/__init__.py
ADDED
|
File without changes
|
scenestreamer/dataset/constants.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Define a lot of constants. It should be totally removed as most of them should be defined by MetaDrive / ScenarioNet.
|
| 3 |
+
"""
|
| 4 |
+
from metadrive.scenario.scenario_description import MetaDriveType
|
| 5 |
+
|
| 6 |
+
# NUM_TYPES = 3
|
| 7 |
+
NUM_TYPES = 5
|
| 8 |
+
|
| 9 |
+
MAP_FEATURE_STATE_DIM = 27
|
| 10 |
+
TRAFFIC_LIGHT_STATE_DIM = 7
|
| 11 |
+
|
| 12 |
+
AGENT_STATE_DIM = 16
|
| 13 |
+
|
| 14 |
+
# ACTOR_PREDICT_DIM = 6 + 2 + 4 + 5 # 3 for position, 1 for heading, 2 for velocity, 5 for types
|
| 15 |
+
TRAFFIC_LIGHT_PREDICT_DIM = 9 # 9 original possible state
|
| 16 |
+
|
| 17 |
+
# TODO(pzh): Do we have to do the normalization? Shouldn't the layer norm solve this?
|
| 18 |
+
# POSITION_XY_RANGE = 100.
|
| 19 |
+
# LOCAL_POSITION_XY_RANGE = 5.
|
| 20 |
+
# HEADING_RANGE = np.pi
|
| 21 |
+
# VELOCITY_XY_RANGE = 10.
|
| 22 |
+
# SIZE_RANGE = 5.
|
| 23 |
+
# MAP_VECTOR_XY_RANGE = 50.
|
| 24 |
+
|
| 25 |
+
# TODO(pzh): Consider remove this.
|
| 26 |
+
object_type_to_int = {
|
| 27 |
+
MetaDriveType.UNSET: 0,
|
| 28 |
+
MetaDriveType.VEHICLE: 1,
|
| 29 |
+
MetaDriveType.PEDESTRIAN: 2,
|
| 30 |
+
MetaDriveType.CYCLIST: 3,
|
| 31 |
+
MetaDriveType.OTHER: 4
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# TODO(pzh): Consider remove this.
|
| 35 |
+
object_int_to_type = {
|
| 36 |
+
-1: MetaDriveType.UNSET,
|
| 37 |
+
0: MetaDriveType.UNSET,
|
| 38 |
+
1: MetaDriveType.VEHICLE,
|
| 39 |
+
2: MetaDriveType.PEDESTRIAN,
|
| 40 |
+
3: MetaDriveType.CYCLIST,
|
| 41 |
+
4: MetaDriveType.OTHER
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
HEADING_PLACEHOLDER = -100 # For the object that has no heading, set this.
|
scenestreamer/dataset/datamodule.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is a wrapper to wrap our dataset as a lightning datamodule.
|
| 3 |
+
"""
|
| 4 |
+
import lightning.pytorch as pl
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
from scenestreamer.dataset import dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SceneStreamerDataModule(pl.LightningDataModule):
|
| 11 |
+
def __init__(
|
| 12 |
+
self, config, train_batch_size, train_num_workers, train_prefetch_factor, val_batch_size, val_num_workers,
|
| 13 |
+
val_prefetch_factor
|
| 14 |
+
):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.config = config
|
| 17 |
+
self.train_batch_size = train_batch_size
|
| 18 |
+
self.train_num_workers = train_num_workers
|
| 19 |
+
self.train_prefetch_factor = train_prefetch_factor
|
| 20 |
+
self.val_batch_size = val_batch_size
|
| 21 |
+
self.val_num_workers = val_num_workers
|
| 22 |
+
self.val_prefetch_factor = val_prefetch_factor
|
| 23 |
+
|
| 24 |
+
def setup(self, stage: str):
|
| 25 |
+
self.train_dataset = dataset.SceneStreamerDataset(config=self.config, mode="training")
|
| 26 |
+
self.val_dataset = dataset.SceneStreamerDataset(config=self.config, mode="test")
|
| 27 |
+
|
| 28 |
+
def train_dataloader(self):
|
| 29 |
+
return DataLoader(
|
| 30 |
+
self.train_dataset,
|
| 31 |
+
batch_size=self.train_batch_size,
|
| 32 |
+
pin_memory=True,
|
| 33 |
+
num_workers=self.train_num_workers,
|
| 34 |
+
shuffle=True,
|
| 35 |
+
persistent_workers=True if self.train_num_workers > 0 else False,
|
| 36 |
+
collate_fn=self.train_dataset.collate_batch,
|
| 37 |
+
prefetch_factor=self.train_prefetch_factor if self.train_num_workers > 0 else None,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def val_dataloader(self):
|
| 41 |
+
return DataLoader(
|
| 42 |
+
self.val_dataset,
|
| 43 |
+
batch_size=self.val_batch_size,
|
| 44 |
+
pin_memory=True,
|
| 45 |
+
num_workers=self.val_num_workers,
|
| 46 |
+
shuffle=False,
|
| 47 |
+
collate_fn=self.val_dataset.collate_batch,
|
| 48 |
+
prefetch_factor=self.val_prefetch_factor if self.val_num_workers > 0 else None,
|
| 49 |
+
)
|
scenestreamer/dataset/dataset.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Create a pytorch dataset class for loading scenario files and padding data entries.
|
| 3 |
+
"""
|
| 4 |
+
import copy
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import pathlib
|
| 8 |
+
import pickle
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import hydra
|
| 12 |
+
except ModuleNotFoundError: # optional for core library usage
|
| 13 |
+
hydra = None
|
| 14 |
+
import numpy as np
|
| 15 |
+
from scenarionet import read_dataset_summary, read_scenario
|
| 16 |
+
from torch.utils.data import Dataset
|
| 17 |
+
|
| 18 |
+
from scenestreamer.dataset.preprocessor import preprocess_scenario_description
|
| 19 |
+
from scenestreamer.utils import global_config
|
| 20 |
+
from scenestreamer.utils import utils
|
| 21 |
+
|
| 22 |
+
# import lmdb
|
| 23 |
+
|
| 24 |
+
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
| 25 |
+
QA_DATASET_MAPPING = {}
|
| 26 |
+
ADV_INFO_DICT = {}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class NoMapFeatureError(Exception):
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LMDBDatasetReader:
|
| 34 |
+
def __init__(self, base_path):
|
| 35 |
+
self.base_path = base_path
|
| 36 |
+
# Load the lookup table that maps sample keys to LMDB file names
|
| 37 |
+
# Search recursively all subfolder to find lookup.json
|
| 38 |
+
self.lookup = {}
|
| 39 |
+
for root, dirs, files in os.walk(self.base_path):
|
| 40 |
+
if "lookup.json" in files:
|
| 41 |
+
lookup_path = os.path.join(root, "lookup.json")
|
| 42 |
+
with open(lookup_path, "r") as f:
|
| 43 |
+
lookup = json.load(f)
|
| 44 |
+
self.lookup.update(lookup)
|
| 45 |
+
self.lmdb_cache = {} # Cache for open LMDB environments
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# def _get_lmdb_env(self, lmdb_name):
|
| 49 |
+
# """Fetches or opens an LMDB environment for reading."""
|
| 50 |
+
# if lmdb_name not in self.lmdb_cache:
|
| 51 |
+
# self.lmdb_cache[lmdb_name] = lmdb.open(lmdb_name, readonly=True)
|
| 52 |
+
# return self.lmdb_cache[lmdb_name]
|
| 53 |
+
|
| 54 |
+
# def load_sample(self, key):
|
| 55 |
+
# """Loads a preprocessed sample by key."""
|
| 56 |
+
# lmdb_name = self.lookup.get(key)
|
| 57 |
+
# if lmdb_name is None:
|
| 58 |
+
# raise KeyError(f"Sample {key} not found in lookup.")
|
| 59 |
+
# env = self._get_lmdb_env(lmdb_name)
|
| 60 |
+
# with env.begin() as txn:
|
| 61 |
+
# npz_bytes = txn.get(key.encode('ascii'))
|
| 62 |
+
# if npz_bytes:
|
| 63 |
+
# with io.BytesIO(npz_bytes) as buffer:
|
| 64 |
+
# data = np.load(buffer, allow_pickle=True)
|
| 65 |
+
# return {name: data[name] for name in data.files} # Return data as a dictionary
|
| 66 |
+
# return None
|
| 67 |
+
|
| 68 |
+
# def close(self):
|
| 69 |
+
# """Closes all open LMDB environments."""
|
| 70 |
+
# for env in self.lmdb_cache.values():
|
| 71 |
+
# env.close()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def process_QA_text_label(QA_dict):
|
| 75 |
+
# TODO: do we need to form label for each individual agent? Rightnow it is just a single label
|
| 76 |
+
labels = {}
|
| 77 |
+
|
| 78 |
+
env_a = QA_dict['env_a']
|
| 79 |
+
labels['env'] = ' '.join(env_a)
|
| 80 |
+
|
| 81 |
+
ego_a = QA_dict['ego_a']
|
| 82 |
+
labels['ego'] = ' '.join(ego_a)
|
| 83 |
+
|
| 84 |
+
int_a = QA_dict['int_a']
|
| 85 |
+
labels['int'] = ' '.join(int_a)
|
| 86 |
+
|
| 87 |
+
return labels
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_file_paths(directory):
|
| 91 |
+
file_paths = []
|
| 92 |
+
# Traverse the directory
|
| 93 |
+
for root, dirs, files in os.walk(directory):
|
| 94 |
+
for file in files:
|
| 95 |
+
# Get the full path and add it to the list
|
| 96 |
+
full_path = os.path.join(root, file)
|
| 97 |
+
file_paths.append(full_path)
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_json_to_dict(file_path):
|
| 102 |
+
"""
|
| 103 |
+
Load a JSON file into a Python dictionary.
|
| 104 |
+
|
| 105 |
+
:param file_path: Path to the JSON file
|
| 106 |
+
:return: Dictionary containing the JSON data
|
| 107 |
+
"""
|
| 108 |
+
try:
|
| 109 |
+
with open(file_path, 'r') as file:
|
| 110 |
+
data = json.load(file)
|
| 111 |
+
return data
|
| 112 |
+
except FileNotFoundError:
|
| 113 |
+
print(f"Error: The file at {file_path} was not found.")
|
| 114 |
+
except json.JSONDecodeError:
|
| 115 |
+
print(f"Error: The file at {file_path} is not a valid JSON file.")
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f"An unexpected error occurred: {e}")
|
| 118 |
+
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class SceneStreamerDataset(Dataset):
|
| 123 |
+
"""
|
| 124 |
+
SceneStreamer dataset class. Returns data_dict for each scenario.
|
| 125 |
+
Init args:
|
| 126 |
+
mode: "training" or "test".
|
| 127 |
+
config:
|
| 128 |
+
- model: Details about the model architecture.
|
| 129 |
+
- data: Data directories, sample intervals, number of agents, etc.
|
| 130 |
+
- evaluation: predict_all_agents, delete_eval_result (TODO: Add ScenarioDescription passthrough as a flag in the config.)
|
| 131 |
+
- optimization: Training hyperparameters.
|
| 132 |
+
- preprocessing: Max number of agents, map features, traffic lights, padding, etc.
|
| 133 |
+
- root_dir: Self-explanatory.
|
| 134 |
+
- sampling: Inference sampling parameters.
|
| 135 |
+
- tokenization: The part of the config passed to the tokenizer.
|
| 136 |
+
"""
|
| 137 |
+
def __init__(self, config, mode):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mode = mode
|
| 140 |
+
self.config = config
|
| 141 |
+
dataset_cfg = self.config.DATA
|
| 142 |
+
|
| 143 |
+
self.max_map_features = config.PREPROCESSING.MAX_MAP_FEATURES
|
| 144 |
+
self.max_vectors_per_map_feature = config.PREPROCESSING.MAX_VECTORS
|
| 145 |
+
self.max_agents = config.PREPROCESSING.MAX_AGENTS
|
| 146 |
+
self.max_traffic_lights = config.PREPROCESSING.MAX_TRAFFIC_LIGHTS
|
| 147 |
+
self.padding_to_max = config.PREPROCESSING.PADDING_TO_MAX
|
| 148 |
+
|
| 149 |
+
# We are expecting the data_dir to be either an absolute path or a relative path w.r.t. the repo root.
|
| 150 |
+
if mode == "training":
|
| 151 |
+
self.data_dir = global_config.ROOT_DIR / dataset_cfg.TRAINING_DATA_DIR
|
| 152 |
+
elif mode == "test":
|
| 153 |
+
self.data_dir = global_config.ROOT_DIR / dataset_cfg.TEST_DATA_DIR
|
| 154 |
+
else:
|
| 155 |
+
raise ValueError(f"Unknown mode {mode}.")
|
| 156 |
+
|
| 157 |
+
# summary_dict: A dictionary of .pkl filenames to ingest. Filenames (keys) are mapped to metadata objects.
|
| 158 |
+
# summary_list: Keys of summary_dict, in order of ingestion.
|
| 159 |
+
# mapping: A dict mapping scenario IDs to the folder that hosts their files.
|
| 160 |
+
summary_dict, summary_list, mapping = read_dataset_summary(self.data_dir)
|
| 161 |
+
|
| 162 |
+
# We might want to use a subset of scenarios.
|
| 163 |
+
if self.mode == "training":
|
| 164 |
+
interval = dataset_cfg.SAMPLE_INTERVAL_TRAINING
|
| 165 |
+
elif self.mode == "test":
|
| 166 |
+
interval = dataset_cfg.SAMPLE_INTERVAL_TEST
|
| 167 |
+
else:
|
| 168 |
+
raise ValueError(f"Unknown mode {self.mode}.")
|
| 169 |
+
|
| 170 |
+
if "SD_PASSTHROUGH" in config.DATA:
|
| 171 |
+
self.return_scenario_description = config.DATA["SD_PASSTHROUGH"]
|
| 172 |
+
else: # Default to False.
|
| 173 |
+
self.return_scenario_description = False
|
| 174 |
+
|
| 175 |
+
summary_list = summary_list[::interval]
|
| 176 |
+
# self.data_summary_dict = {k: summary_dict[k] for k in summary_list}
|
| 177 |
+
self.data_mapping = {k: mapping[k] for k in summary_list}
|
| 178 |
+
self.length = len(summary_list)
|
| 179 |
+
self.use_cache_logged = False
|
| 180 |
+
|
| 181 |
+
if self.config.BACKWARD_PREDICTION and self.mode == "training":
|
| 182 |
+
self.real_length = self.length
|
| 183 |
+
self.length = self.length * 2
|
| 184 |
+
|
| 185 |
+
# Convert each string to sequence of codepoints (integer),
|
| 186 |
+
# and then pack them into a numpy array.
|
| 187 |
+
# NOTE(pzh): I forgot why I wrote this. Seems like some issues in multiprocessing.
|
| 188 |
+
|
| 189 |
+
# seqs: A list of np.arrays, each representing the ascii values of a string.
|
| 190 |
+
seqs = [utils.string_to_sequence(s) for s in summary_list]
|
| 191 |
+
|
| 192 |
+
# strings_v: ascii values of all strings, concatenated.
|
| 193 |
+
# strings_o: offsets of each string in strings_v.
|
| 194 |
+
if len(seqs) == 0:
|
| 195 |
+
raise ValueError("No scenarios found in the dataset: {}".format(self.data_dir))
|
| 196 |
+
self.strings_v, self.strings_o = utils.pack_sequences(seqs)
|
| 197 |
+
|
| 198 |
+
# if self.config.DATA.USE_LMDB and self.mode == "training":
|
| 199 |
+
# cache_folder = pathlib.Path(self.data_dir) / "cache"
|
| 200 |
+
# assert cache_folder.is_dir()
|
| 201 |
+
# self.reader = LMDBDatasetReader(cache_folder) # LMDB Reader to load samples
|
| 202 |
+
|
| 203 |
+
from scenestreamer.tokenization import get_tokenizer
|
| 204 |
+
self.tokenizer = get_tokenizer(config=self.config)
|
| 205 |
+
|
| 206 |
+
def __len__(self):
|
| 207 |
+
return self.length
|
| 208 |
+
|
| 209 |
+
def __getitem__(self, index):
|
| 210 |
+
# Unpack the stored codepoints at the correct index into a filename string.
|
| 211 |
+
|
| 212 |
+
use_backward_prediction = False
|
| 213 |
+
if self.config.BACKWARD_PREDICTION and self.mode == "training":
|
| 214 |
+
if index >= self.real_length:
|
| 215 |
+
index = index - self.real_length
|
| 216 |
+
use_backward_prediction = True
|
| 217 |
+
|
| 218 |
+
seq = utils.unpack_sequence(self.strings_v, self.strings_o, index)
|
| 219 |
+
string = utils.sequence_to_string(seq)
|
| 220 |
+
file_name = string
|
| 221 |
+
|
| 222 |
+
try:
|
| 223 |
+
data_dict = self.create_scene_level_data(file_name, index, use_backward_prediction)
|
| 224 |
+
except NoMapFeatureError:
|
| 225 |
+
# This is workaround for Waymo test set where some scenarios do not have map features.
|
| 226 |
+
return self.__getitem__(index + 1)
|
| 227 |
+
|
| 228 |
+
# If self.return_scenario_description is true, data_dict has an extra key [raw_scenario_description] that contains the ScenarioDescription object.
|
| 229 |
+
return data_dict
|
| 230 |
+
|
| 231 |
+
def create_scene_level_data(self, file_name, index, use_backward_prediction=False):
|
| 232 |
+
"""
|
| 233 |
+
Reads a scenario file and preprocesses it.
|
| 234 |
+
"""
|
| 235 |
+
assert not self.config.DATA.USE_LMDB, "LMDB is not supported."
|
| 236 |
+
try:
|
| 237 |
+
# scenario: A ScenarioDescription instance.
|
| 238 |
+
cache = None
|
| 239 |
+
scenario = None
|
| 240 |
+
cache_path = None
|
| 241 |
+
|
| 242 |
+
if self.config.DATA.USE_CACHE:
|
| 243 |
+
cache_folder = pathlib.Path(self.data_dir) / "cache"
|
| 244 |
+
if cache_folder.is_dir() is False:
|
| 245 |
+
cache_folder.mkdir(exist_ok=True)
|
| 246 |
+
|
| 247 |
+
cache_path = pathlib.Path(self.data_dir) / "cache" / file_name
|
| 248 |
+
if cache_path.is_file():
|
| 249 |
+
|
| 250 |
+
try:
|
| 251 |
+
with open(cache_path, "rb") as f:
|
| 252 |
+
cache = pickle.load(f)
|
| 253 |
+
|
| 254 |
+
if self.use_cache_logged is False:
|
| 255 |
+
print("=====================================")
|
| 256 |
+
print("=====================================")
|
| 257 |
+
print("\t*** WARNING ***")
|
| 258 |
+
print("\tYou are using cache files!!!")
|
| 259 |
+
print("\tIn folder: ", cache_folder)
|
| 260 |
+
print("\tThere are ", len(list(cache_folder.glob("*"))), " cache files!!!")
|
| 261 |
+
print("=====================================")
|
| 262 |
+
print("=====================================")
|
| 263 |
+
|
| 264 |
+
self.use_cache_logged = True
|
| 265 |
+
|
| 266 |
+
return cache
|
| 267 |
+
except EOFError as e:
|
| 268 |
+
print(f"Error in reading cache file: {cache_path=}")
|
| 269 |
+
|
| 270 |
+
scenario = read_scenario(
|
| 271 |
+
dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
else:
|
| 275 |
+
scenario = read_scenario(
|
| 276 |
+
dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
|
| 277 |
+
)
|
| 278 |
+
# print("Cannot find cache file: ", cache_path, "Creating one.")
|
| 279 |
+
|
| 280 |
+
else:
|
| 281 |
+
# if self.config.DATA.USE_LMDB and self.mode == "training":
|
| 282 |
+
# cache = self.reader.load_sample(file_name)
|
| 283 |
+
# else:
|
| 284 |
+
scenario = read_scenario(
|
| 285 |
+
dataset_path=self.data_dir, mapping=self.data_mapping, scenario_file_name=file_name
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
except EOFError as e:
|
| 289 |
+
print(f"{self.data_dir=}, {self.data_mapping=}, {file_name=}")
|
| 290 |
+
raise e
|
| 291 |
+
assert self.mode in ["training", "test"], self.mode
|
| 292 |
+
ret = {}
|
| 293 |
+
|
| 294 |
+
if len(scenario["map_features"]) == 0:
|
| 295 |
+
raise NoMapFeatureError
|
| 296 |
+
|
| 297 |
+
if self.return_scenario_description:
|
| 298 |
+
ret["raw_scenario_description"] = copy.deepcopy(scenario)
|
| 299 |
+
|
| 300 |
+
# TODO: Remove error handling after debugging.
|
| 301 |
+
try:
|
| 302 |
+
preprocessed_scenario_description = preprocess_scenario_description(
|
| 303 |
+
scenario=scenario,
|
| 304 |
+
# cache=cache,
|
| 305 |
+
config=copy.deepcopy(self.config),
|
| 306 |
+
in_evaluation=self.mode != "training",
|
| 307 |
+
keep_all_data=self.config.PREPROCESSING.get("keep_all_data", False),
|
| 308 |
+
backward_prediction=use_backward_prediction,
|
| 309 |
+
tokenizer=self.tokenizer,
|
| 310 |
+
# cache_path=cache_path,
|
| 311 |
+
)
|
| 312 |
+
preprocessed_scenario_description["file_name"] = file_name
|
| 313 |
+
except Exception as e:
|
| 314 |
+
print(f"Error in preprocessing {file_name=}, {index=}, {scenario['id']=}")
|
| 315 |
+
# Ensure that the exception is not swallowed by adding this.
|
| 316 |
+
raise RuntimeError(
|
| 317 |
+
f"{file_name=}, {index=}, {scenario['id']=}. Error in create_scene_level_data: {e}"
|
| 318 |
+
) from e
|
| 319 |
+
|
| 320 |
+
ret.update(preprocessed_scenario_description)
|
| 321 |
+
ret.update({"metadata/scenario_id": scenario['id']})
|
| 322 |
+
|
| 323 |
+
if cache_path is not None:
|
| 324 |
+
with open(cache_path, "wb") as f:
|
| 325 |
+
pickle.dump(ret, f)
|
| 326 |
+
# print("Writing cache file: ", cache_path)
|
| 327 |
+
|
| 328 |
+
return ret
|
| 329 |
+
|
| 330 |
+
def collate_batch(self, batch_list):
|
| 331 |
+
"""
|
| 332 |
+
Output format:
|
| 333 |
+
|
| 334 |
+
agent_feature: [B, T, #agents, D]
|
| 335 |
+
agent_feature_position: [B, T, #agents, 3]
|
| 336 |
+
map_feature: [B, T, #mapfeat, #points, D]
|
| 337 |
+
map_feature_valid_mask: [B, T, #mapfeat, #points]
|
| 338 |
+
map_feature_position: [B, T, #mapfeat, 3]
|
| 339 |
+
"""
|
| 340 |
+
data_dict_sample = batch_list[0]
|
| 341 |
+
|
| 342 |
+
num_map_feat, num_points, _ = data_dict_sample["encoder/map_feature"].shape
|
| 343 |
+
|
| 344 |
+
data_dict = {}
|
| 345 |
+
object_keys = [
|
| 346 |
+
"raw_scenario_description",
|
| 347 |
+
"encoder/track_name",
|
| 348 |
+
"decoder/track_name",
|
| 349 |
+
"eval/track_name",
|
| 350 |
+
# "scenario_id",
|
| 351 |
+
# "in_evaluation"
|
| 352 |
+
] # Keys exempt from padding and tensor conversion.
|
| 353 |
+
|
| 354 |
+
for k in set(data_dict_sample.keys()):
|
| 355 |
+
if k not in object_keys:
|
| 356 |
+
if not isinstance(data_dict_sample[k], np.ndarray):
|
| 357 |
+
assert isinstance(data_dict_sample[k], (int, float, bool, str)), (k, type(data_dict_sample[k]))
|
| 358 |
+
if isinstance(data_dict_sample[k], str):
|
| 359 |
+
data_dict[k] = np.array([b[k] for b in batch_list])
|
| 360 |
+
else:
|
| 361 |
+
data_dict[k] = utils.numpy_to_torch(np.array([b[k] for b in batch_list]))
|
| 362 |
+
continue
|
| 363 |
+
# else:
|
| 364 |
+
# if batch_list[0][k].dtype == np.object:
|
| 365 |
+
# data_dict[k] = [b[k] for b in batch_list]
|
| 366 |
+
# continue
|
| 367 |
+
|
| 368 |
+
val_list = [utils.numpy_to_torch(b[k]) for b in batch_list]
|
| 369 |
+
|
| 370 |
+
# Map features that have vectors' information
|
| 371 |
+
if k in [
|
| 372 |
+
"encoder/map_feature",
|
| 373 |
+
"vis/map_feature",
|
| 374 |
+
"raw/map_feature",
|
| 375 |
+
"encoder/map_feature_valid_mask",
|
| 376 |
+
]:
|
| 377 |
+
data_dict[k] = utils.padding_1st_and_2nd_dim(
|
| 378 |
+
val_list,
|
| 379 |
+
max_1st_dim=self.max_map_features if self.padding_to_max else None,
|
| 380 |
+
max_2nd_dim=self.max_vectors_per_map_feature if self.padding_to_max else None
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Map features that have aggregated info from vectors
|
| 384 |
+
elif k in [
|
| 385 |
+
"encoder/map_heading",
|
| 386 |
+
"encoder/map_position",
|
| 387 |
+
"encoder/map_valid_mask",
|
| 388 |
+
]:
|
| 389 |
+
data_dict[k] = utils.padding_1st_dim(
|
| 390 |
+
val_list, max_1st_dim=self.max_map_features if self.padding_to_max else None
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# Traffic light features that have temporal dim
|
| 394 |
+
elif k in [
|
| 395 |
+
"encoder/traffic_light_feature",
|
| 396 |
+
"encoder/traffic_light_state",
|
| 397 |
+
"encoder/traffic_light_valid_mask",
|
| 398 |
+
]:
|
| 399 |
+
|
| 400 |
+
if self.config.PREPROCESSING.REMOVE_TRAFFIC_LIGHT_STATE:
|
| 401 |
+
data_dict[k] = utils.padding_1st_dim(
|
| 402 |
+
val_list, max_1st_dim=self.max_traffic_lights if self.padding_to_max else None
|
| 403 |
+
)
|
| 404 |
+
else:
|
| 405 |
+
data_dict[k] = utils.padding_1st_and_2nd_dim(
|
| 406 |
+
val_list, max_2nd_dim=self.max_traffic_lights if self.padding_to_max else None
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# Traffic light features that do not have temporal dim
|
| 410 |
+
elif k in [
|
| 411 |
+
"encoder/traffic_light_position",
|
| 412 |
+
"encoder/traffic_light_heading",
|
| 413 |
+
"encoder/traffic_light_map_id",
|
| 414 |
+
]:
|
| 415 |
+
data_dict[k] = utils.padding_1st_dim(
|
| 416 |
+
val_list, max_1st_dim=self.max_traffic_lights if self.padding_to_max else None
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Agent features
|
| 420 |
+
elif k in [
|
| 421 |
+
"encoder/agent_feature",
|
| 422 |
+
"encoder/agent_position",
|
| 423 |
+
"encoder/agent_valid_mask",
|
| 424 |
+
"encoder/agent_heading",
|
| 425 |
+
"encoder/agent_velocity",
|
| 426 |
+
"decoder/modeled_agent_position",
|
| 427 |
+
"decoder/modeled_agent_heading",
|
| 428 |
+
"decoder/modeled_agent_velocity",
|
| 429 |
+
"decoder/modeled_agent_delta",
|
| 430 |
+
]:
|
| 431 |
+
data_dict[k] = utils.padding_1st_and_2nd_dim(
|
| 432 |
+
val_list, max_2nd_dim=self.max_agents if self.padding_to_max else None
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# Other data that does not pass the model or does not need regular shapes
|
| 436 |
+
elif k in [
|
| 437 |
+
# "encoder/modeled_agent_id",
|
| 438 |
+
# "action_label/labeled_agent_id",
|
| 439 |
+
"metadata/map_center", # "decoder/input_step",
|
| 440 |
+
# "decoder/input_intra_step",
|
| 441 |
+
"encoder/current_agent_heading",
|
| 442 |
+
"decoder/current_agent_heading",
|
| 443 |
+
"encoder/current_agent_shape",
|
| 444 |
+
"decoder/current_agent_shape",
|
| 445 |
+
"eval/current_agent_heading",
|
| 446 |
+
"encoder/current_agent_valid_mask",
|
| 447 |
+
"decoder/current_agent_valid_mask",
|
| 448 |
+
"eval/current_agent_valid_mask",
|
| 449 |
+
# "decoder/current_agent_valid_mask", #
|
| 450 |
+
# "decoder/modeled_agent_indices",
|
| 451 |
+
# For gen model:
|
| 452 |
+
# "decoder/input_token_valid_mask",
|
| 453 |
+
# "decoder/should_predict",
|
| 454 |
+
# "decoder/is_gt",
|
| 455 |
+
# "eval/should_predict_motion",
|
| 456 |
+
]:
|
| 457 |
+
data_dict[k] = utils.padding_1st_dim(val_list)
|
| 458 |
+
|
| 459 |
+
elif k in [
|
| 460 |
+
"decoder/input_action_valid_mask",
|
| 461 |
+
"encoder/current_agent_position",
|
| 462 |
+
"decoder/current_agent_position",
|
| 463 |
+
"encoder/current_agent_velocity",
|
| 464 |
+
"decoder/current_agent_velocity",
|
| 465 |
+
"decoder/target_action_valid_mask",
|
| 466 |
+
#"decoder/future_agent_position",
|
| 467 |
+
#"decoder/future_agent_heading",
|
| 468 |
+
#"decoder/future_agent_valid_mask",
|
| 469 |
+
#"decoder/future_agent_velocity",
|
| 470 |
+
#"encoder/future_agent_position",
|
| 471 |
+
#"encoder/future_agent_heading",
|
| 472 |
+
#"encoder/future_agent_valid_mask",
|
| 473 |
+
#"encoder/future_agent_velocity",
|
| 474 |
+
"decoder/agent_position",
|
| 475 |
+
"decoder/agent_heading",
|
| 476 |
+
"decoder/agent_velocity",
|
| 477 |
+
"decoder/agent_valid_mask",
|
| 478 |
+
"eval/agent_velocity",
|
| 479 |
+
"eval/agent_heading",
|
| 480 |
+
"eval/agent_position",
|
| 481 |
+
"eval/agent_valid_mask",
|
| 482 |
+
"encoder/agent_shape",
|
| 483 |
+
"decoder/agent_shape",
|
| 484 |
+
"eval/agent_shape", # "decoder/target_valid_mask",
|
| 485 |
+
"decoder/input_agent_motion",
|
| 486 |
+
"decoder/target_agent_motion",
|
| 487 |
+
"decoder/dest_map_index_valid_mask",
|
| 488 |
+
]:
|
| 489 |
+
data_dict[k] = utils.padding_1st_and_2nd_dim(val_list)
|
| 490 |
+
|
| 491 |
+
elif k in [
|
| 492 |
+
"encoder/agent_type",
|
| 493 |
+
"decoder/agent_type",
|
| 494 |
+
"encoder/modeled_agent_type",
|
| 495 |
+
"eval/agent_type", # "eval/raw_agent_name",
|
| 496 |
+
"encoder/object_of_interest_name",
|
| 497 |
+
"decoder/object_of_interest_name",
|
| 498 |
+
"metadata/sdc_name", # "eval/modeled_agent_id",
|
| 499 |
+
"encoder/object_of_interest_id",
|
| 500 |
+
"decoder/object_of_interest_id",
|
| 501 |
+
"encoder/modeled_agent_id", # "decoder/modeled_agent_id",
|
| 502 |
+
"encoder/agent_id",
|
| 503 |
+
"decoder/agent_id",
|
| 504 |
+
"decoder/labeled_agent_id",
|
| 505 |
+
"decoder/label_turning",
|
| 506 |
+
"decoder/label_acceleration",
|
| 507 |
+
"decoder/label_safety",
|
| 508 |
+
# For gen model:
|
| 509 |
+
# "decoder/input_token_id",
|
| 510 |
+
# "decoder/causal_mask_offset",
|
| 511 |
+
]:
|
| 512 |
+
data_dict[k] = utils.padding_1st_dim(val_list, fill=-1)
|
| 513 |
+
|
| 514 |
+
elif k in [
|
| 515 |
+
"decoder/dest_map_index",
|
| 516 |
+
"decoder/dest_map_index_gt",
|
| 517 |
+
]:
|
| 518 |
+
data_dict[k] = utils.padding_1st_and_2nd_dim(val_list, fill=-1)
|
| 519 |
+
|
| 520 |
+
elif k in [
|
| 521 |
+
"decoder/input_action",
|
| 522 |
+
"decoder/target_action",
|
| 523 |
+
"decoder/input_action_for_trafficgen",
|
| 524 |
+
|
| 525 |
+
"decoder/current_agent_shape_for_trafficgen",
|
| 526 |
+
"decoder/modeled_agent_heading_for_trafficgen",
|
| 527 |
+
"decoder/modeled_agent_position_for_trafficgen",
|
| 528 |
+
"decoder/modeled_agent_velocity_for_trafficgen",
|
| 529 |
+
"decoder/input_action_valid_mask_for_trafficgen",
|
| 530 |
+
"decoder/modeled_agent_delta_for_trafficgen",
|
| 531 |
+
"decoder/input_action_feature_for_trafficgen",
|
| 532 |
+
"decoder/target_offset_for_trafficgen",
|
| 533 |
+
"decoder/input_offset_for_trafficgen",
|
| 534 |
+
"decoder/agent_id_for_trafficgen",
|
| 535 |
+
"decoder/trafficgen_position",
|
| 536 |
+
"decoder/trafficgen_heading",
|
| 537 |
+
"decoder/agent_type_for_trafficgen",
|
| 538 |
+
]:
|
| 539 |
+
data_dict[k] = utils.padding_all_dims(val_list, fill=-1)
|
| 540 |
+
|
| 541 |
+
elif k in object_keys:
|
| 542 |
+
# Passthrough: Have the data_dict[object] contain a list of objects.
|
| 543 |
+
data_dict[k] = [b[k] for b in batch_list]
|
| 544 |
+
|
| 545 |
+
elif k in [
|
| 546 |
+
"encoder/sdc_index",
|
| 547 |
+
]:
|
| 548 |
+
pass
|
| 549 |
+
|
| 550 |
+
else:
|
| 551 |
+
raise ValueError("Unknown key: {}".format(k))
|
| 552 |
+
|
| 553 |
+
return data_dict
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
if hydra is not None:
|
| 557 |
+
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1009_safety_action_debug.yaml")
|
| 558 |
+
def debug(config):
|
| 559 |
+
test_dataset = SceneStreamerDataset(config, "training")
|
| 560 |
+
ddd = iter(test_dataset)
|
| 561 |
+
count = 0
|
| 562 |
+
buggy_count = 0
|
| 563 |
+
while True:
|
| 564 |
+
if count == 3:
|
| 565 |
+
return
|
| 566 |
+
try:
|
| 567 |
+
data = next(ddd)
|
| 568 |
+
count += 1
|
| 569 |
+
|
| 570 |
+
assert data["decoder/label_safety"][data["decoder/labeled_agent_id"]].sum() > 1
|
| 571 |
+
|
| 572 |
+
except StopIteration:
|
| 573 |
+
break
|
| 574 |
+
|
| 575 |
+
except AssertionError:
|
| 576 |
+
print("ni collision")
|
| 577 |
+
buggy_count += 1
|
| 578 |
+
print("scenario_id", data["scenario_id"])
|
| 579 |
+
print("data['decoder/label_safety']", data["decoder/label_safety"])
|
| 580 |
+
print("data['decoder/labeled_agent_id']", data["decoder/labeled_agent_id"])
|
| 581 |
+
print("track_name", data["decoder/track_name"][data["decoder/labeled_agent_id"]])
|
| 582 |
+
|
| 583 |
+
print("buggy_count:", buggy_count)
|
| 584 |
+
print("count", count)
|
| 585 |
+
print("End")
|
| 586 |
+
|
| 587 |
+
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml")
|
| 588 |
+
def read_traffic_light_state(config):
|
| 589 |
+
test_dataset = SceneStreamerDataset(config, "training")
|
| 590 |
+
|
| 591 |
+
total_tl = 0
|
| 592 |
+
total_green = 0
|
| 593 |
+
total_yellow = 0
|
| 594 |
+
total_red = 0
|
| 595 |
+
total_unknown = 0
|
| 596 |
+
total_mix = 0
|
| 597 |
+
import tqdm
|
| 598 |
+
|
| 599 |
+
for data in tqdm.tqdm(test_dataset):
|
| 600 |
+
tl = data["encoder/traffic_light_feature"]
|
| 601 |
+
mask = data["encoder/traffic_light_valid_mask"]
|
| 602 |
+
|
| 603 |
+
for i in range(tl.shape[1]):
|
| 604 |
+
if mask[:, i].any():
|
| 605 |
+
is_green = tl[:, i, 3].astype(bool).any()
|
| 606 |
+
is_yellow = tl[:, i, 4].astype(bool).any()
|
| 607 |
+
is_red = tl[:, i, 5].astype(bool).any()
|
| 608 |
+
is_unknown = tl[:, i, 6].astype(bool).any()
|
| 609 |
+
|
| 610 |
+
total_tl += 1
|
| 611 |
+
total_green += is_green
|
| 612 |
+
total_yellow += is_yellow
|
| 613 |
+
total_red += is_red
|
| 614 |
+
total_unknown += is_unknown
|
| 615 |
+
total_mix += (is_green and is_yellow) or (is_green and is_red) or (is_yellow and is_red)
|
| 616 |
+
|
| 617 |
+
print("total_tl:", total_tl)
|
| 618 |
+
print("total_green: {}\t{:.4f}".format(total_green, total_green / total_tl))
|
| 619 |
+
print("total_yellow: {}\t{:.4f}".format(total_yellow, total_yellow / total_tl))
|
| 620 |
+
print("total_red: {}\t{:.4f}".format(total_red, total_red / total_tl))
|
| 621 |
+
print("total_unknown: {}\t{:.4f}".format(total_unknown, total_unknown / total_tl))
|
| 622 |
+
print("total_mix: {}\t{:.4f}".format(total_mix, total_mix / total_tl))
|
| 623 |
+
else:
|
| 624 |
+
debug = None
|
| 625 |
+
read_traffic_light_state = None
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
if __name__ == '__main__':
|
| 629 |
+
# debug()
|
| 630 |
+
read_traffic_light_state()
|
scenestreamer/dataset/make_lmdb.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Only the TRAINING_DATA_DIR will be used in the code below.
|
| 3 |
+
Usage:
|
| 4 |
+
|
| 5 |
+
python -m scenestreamer.dataset.make_lmdb \
|
| 6 |
+
--config-name="1024_gpt" DATA.TEST_DATA_DIR='data/20scenarios' \
|
| 7 |
+
DATA.TRAINING_DATA_DIR="/data_zhenghao/datasets/scenarionet/CAT_waymo_hybrid/"
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import pathlib
|
| 13 |
+
import pickle
|
| 14 |
+
import multiprocessing as mp
|
| 15 |
+
from functools import partial
|
| 16 |
+
import tqdm
|
| 17 |
+
import hydra
|
| 18 |
+
import lmdb
|
| 19 |
+
import omegaconf
|
| 20 |
+
import tqdm
|
| 21 |
+
|
| 22 |
+
from scenestreamer.dataset.dataset import SceneStreamerDataset
|
| 23 |
+
|
| 24 |
+
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LMDBBulkWriter:
|
| 28 |
+
def __init__(self, base_path, max_size=1e9):
|
| 29 |
+
"""
|
| 30 |
+
Initializes the LMDBBulkWriter to save all data in batches, with map_size for each LMDB file.
|
| 31 |
+
Args:
|
| 32 |
+
base_path: Directory path to save LMDB files.
|
| 33 |
+
max_size: Maximum size of each LMDB file in bytes.
|
| 34 |
+
"""
|
| 35 |
+
self.base_path = base_path
|
| 36 |
+
# Create the cache directory if it doesn't exist
|
| 37 |
+
os.makedirs(self.base_path, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
self.max_size = int(max_size) # Set the max LMDB file size (e.g., 1 GB)
|
| 40 |
+
self.current_db_index = 0
|
| 41 |
+
self.lookup = {} # Lookup table to track which LMDB file stores which sample
|
| 42 |
+
self.current_db = self._open_new_lmdb(self.current_db_index)
|
| 43 |
+
self.per_shard_size = 0
|
| 44 |
+
|
| 45 |
+
self.sample_buffer = []
|
| 46 |
+
|
| 47 |
+
def _open_new_lmdb(self, db_index):
|
| 48 |
+
"""Opens a new LMDB file for saving samples."""
|
| 49 |
+
db_path = f"{self.base_path}/data_{db_index}.lmdb"
|
| 50 |
+
return lmdb.open(db_path, map_size=self.max_size)
|
| 51 |
+
|
| 52 |
+
def _save_a_batch(self):
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
# Commit the transaction if we have reached the commit interval
|
| 56 |
+
# if (not hasattr(self, 'txn')) or (self.txn is None):
|
| 57 |
+
# self.txn = self.current_db.begin(write=True) # Start a new transaction
|
| 58 |
+
#
|
| 59 |
+
#
|
| 60 |
+
# for key, data in self.sample_buffer:
|
| 61 |
+
# self.txn.put(key.encode('ascii'), pickle.dumps(data))
|
| 62 |
+
# self.lookup[key] = f"data_{self.current_db_index}.lmdb"
|
| 63 |
+
#
|
| 64 |
+
# if hasattr(self, 'txn') and self.txn:
|
| 65 |
+
# self.txn.commit() # Commit the transaction
|
| 66 |
+
print(f"Saving {len(self.sample_buffer)} samples to data_{self.current_db_index}.lmdb")
|
| 67 |
+
with self.current_db.begin(write=True) as txn:
|
| 68 |
+
for key, data in self.sample_buffer:
|
| 69 |
+
txn.put(key.encode('ascii'), pickle.dumps(data))
|
| 70 |
+
self.lookup[key] = f"data_{self.current_db_index}.lmdb"
|
| 71 |
+
|
| 72 |
+
self.sample_buffer.clear()
|
| 73 |
+
|
| 74 |
+
except lmdb.MapFullError:
|
| 75 |
+
|
| 76 |
+
# If current LMDB file is full, create a new one and retry saving
|
| 77 |
+
self.current_db.close()
|
| 78 |
+
self.current_db_index += 1
|
| 79 |
+
print(f"Creating new LMDB file: data_{self.current_db_index}.lmdb (size: {self.per_shard_size})")
|
| 80 |
+
self.current_db = self._open_new_lmdb(self.current_db_index)
|
| 81 |
+
self._save_a_batch()
|
| 82 |
+
self.per_shard_size = 0
|
| 83 |
+
|
| 84 |
+
def save_sample(self, key, data):
|
| 85 |
+
"""Saves a sample to the current LMDB file, switching to a new file if necessary."""
|
| 86 |
+
# Batch writes into a single transaction
|
| 87 |
+
if self.per_shard_size % 100 == 0:
|
| 88 |
+
self._save_a_batch()
|
| 89 |
+
self.sample_buffer.append((key, data))
|
| 90 |
+
self.per_shard_size += 1
|
| 91 |
+
|
| 92 |
+
def close(self):
|
| 93 |
+
self._save_a_batch()
|
| 94 |
+
"""Closes the LMDB environment and saves the lookup table as a JSON file."""
|
| 95 |
+
self.current_db.close()
|
| 96 |
+
# Save the lookup table to track the LMDB file where each sample is stored
|
| 97 |
+
with open(f"{self.base_path}/lookup.json", "w") as f:
|
| 98 |
+
json.dump(self.lookup, f)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def preprocess_and_queue_worker(worker_id, config, indices, queue):
|
| 102 |
+
"""
|
| 103 |
+
This function runs in each worker to preprocess samples and send them to the write queue.
|
| 104 |
+
The writer process will handle writing to LMDB.
|
| 105 |
+
"""
|
| 106 |
+
print(f"Worker {worker_id} started.")
|
| 107 |
+
dataset = SceneStreamerDataset(config, "training")
|
| 108 |
+
|
| 109 |
+
print(f"Worker {worker_id} has {len(dataset)} samples.")
|
| 110 |
+
|
| 111 |
+
# Process and queue each sample assigned to this worker
|
| 112 |
+
if worker_id == 0:
|
| 113 |
+
pbar = tqdm.tqdm(indices, desc="Worker %d" % worker_id)
|
| 114 |
+
else:
|
| 115 |
+
pbar = indices
|
| 116 |
+
print(f"Worker {worker_id} has {len(indices)} samples.")
|
| 117 |
+
|
| 118 |
+
for i in pbar:
|
| 119 |
+
sample = dataset[i] # Access the sample using its index
|
| 120 |
+
|
| 121 |
+
# Simulate some preprocessing (replace with actual preprocessing logic)
|
| 122 |
+
file_name, processed_sample = sample["file_name"], sample
|
| 123 |
+
|
| 124 |
+
# Put the preprocessed sample into the queue to be written by the writer process
|
| 125 |
+
print(f"Worker {worker_id} processed {file_name}")
|
| 126 |
+
queue.put((file_name, processed_sample))
|
| 127 |
+
|
| 128 |
+
# Signal that this worker is done
|
| 129 |
+
# queue.put(None) # 'None' signals that the worker is done
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def write_process(queue, base_path, max_size):
|
| 133 |
+
"""
|
| 134 |
+
The write process receives samples from the queue and writes them to the LMDB environment.
|
| 135 |
+
"""
|
| 136 |
+
writer = LMDBBulkWriter(base_path=base_path, max_size=max_size)
|
| 137 |
+
print("Writer process started.")
|
| 138 |
+
|
| 139 |
+
while True:
|
| 140 |
+
|
| 141 |
+
# Blocking if no data is available
|
| 142 |
+
data = queue.get()
|
| 143 |
+
|
| 144 |
+
if data == 100:
|
| 145 |
+
print("Received 100, stopping writer process.")
|
| 146 |
+
# If 'None' is received, this indicates that a worker has finished
|
| 147 |
+
break
|
| 148 |
+
|
| 149 |
+
if data is None:
|
| 150 |
+
print("Received None, stopping writer process.")
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
file_name, processed_sample = data
|
| 154 |
+
print(f"Saved {file_name} to LMDB")
|
| 155 |
+
writer.save_sample(file_name, processed_sample)
|
| 156 |
+
|
| 157 |
+
# Close the writer once all workers are done
|
| 158 |
+
writer.close()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="1024_gpt.yaml")
|
| 162 |
+
def make_lmdb(config):
|
| 163 |
+
omegaconf.OmegaConf.set_struct(config, False)
|
| 164 |
+
omegaconf.OmegaConf.set_struct(config, True)
|
| 165 |
+
|
| 166 |
+
dataset = SceneStreamerDataset(config, "training")
|
| 167 |
+
folder = pathlib.Path(dataset.data_dir)
|
| 168 |
+
folder = folder / "cache"
|
| 169 |
+
folder.mkdir(parents=True, exist_ok=False)
|
| 170 |
+
|
| 171 |
+
# Initialize the LMDBBulkWriter
|
| 172 |
+
print("Saving data to LMDB folder:", folder.absolute())
|
| 173 |
+
|
| 174 |
+
# num_workers = mp.cpu_count()
|
| 175 |
+
num_workers = 2
|
| 176 |
+
|
| 177 |
+
dataset_size = len(dataset)
|
| 178 |
+
indices = list(range(dataset_size))
|
| 179 |
+
chunk_size = dataset_size // num_workers
|
| 180 |
+
|
| 181 |
+
# Split the indices into chunks, one for each worker
|
| 182 |
+
chunked_indices = [indices[i * chunk_size:(i + 1) * chunk_size] for i in range(num_workers)]
|
| 183 |
+
# The final chunk may have more samples if the dataset size is not divisible by the number of workers.
|
| 184 |
+
chunked_indices[0].extend(indices[num_workers * chunk_size:])
|
| 185 |
+
|
| 186 |
+
# Create a multiprocessing queue
|
| 187 |
+
queue = mp.Queue()
|
| 188 |
+
|
| 189 |
+
# Create and start the writer process
|
| 190 |
+
writer_process = mp.Process(target=write_process, args=(queue, folder, 1e10))
|
| 191 |
+
|
| 192 |
+
writer_process.start()
|
| 193 |
+
|
| 194 |
+
# Create a multiprocessing pool for parallel processing (preprocessing)
|
| 195 |
+
pool = mp.Pool(num_workers)
|
| 196 |
+
|
| 197 |
+
results = []
|
| 198 |
+
# Start each worker process, passing its chunk of indices
|
| 199 |
+
for worker_id, worker_indices in enumerate(chunked_indices):
|
| 200 |
+
print(f"Starting worker {worker_id} with {len(worker_indices)} samples.")
|
| 201 |
+
result = pool.apply_async(preprocess_and_queue_worker, args=(worker_id, config, worker_indices, queue))
|
| 202 |
+
results.append(result)
|
| 203 |
+
|
| 204 |
+
# Wait for all worker processes to complete
|
| 205 |
+
# for result in results:
|
| 206 |
+
# result.get() # This will block until the worker completes its task
|
| 207 |
+
# preprocess_and_queue_worker(0, config, chunked_indices[0], queue)
|
| 208 |
+
pool.close()
|
| 209 |
+
print("Waiting for workers to finish...")
|
| 210 |
+
pool.join()
|
| 211 |
+
print("All workers finished.")
|
| 212 |
+
|
| 213 |
+
# Signal the writer process to stop (send 'None' once all workers are done)
|
| 214 |
+
queue.put(100)
|
| 215 |
+
# Wait for the writer process to finish
|
| 216 |
+
writer_process.join()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@hydra.main(version_base=None, config_path=str(REPO_ROOT / "cfgs"), config_name="motion_default.yaml")
|
| 220 |
+
def debug(config):
|
| 221 |
+
omegaconf.OmegaConf.set_struct(config, False)
|
| 222 |
+
omegaconf.OmegaConf.set_struct(config, True)
|
| 223 |
+
dataset = SceneStreamerDataset(config, "training")
|
| 224 |
+
folder = pathlib.Path(dataset.data_dir)
|
| 225 |
+
folder = folder / "cache"
|
| 226 |
+
folder.mkdir(parents=True, exist_ok=True)
|
| 227 |
+
for i, sample in enumerate(tqdm.tqdm(dataset, total=len(dataset), desc="Scenarios")):
|
| 228 |
+
file_name = sample["file_name"]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == '__main__':
|
| 232 |
+
make_lmdb()
|
| 233 |
+
# debug()
|
scenestreamer/dataset/preprocess_action_label.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from shapely.geometry import Polygon
|
| 3 |
+
|
| 4 |
+
from scenestreamer.utils import utils
|
| 5 |
+
|
| 6 |
+
INVALID_VALUE = -10000
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TurnAction:
|
| 10 |
+
STOP = 0
|
| 11 |
+
KEEP_STRAIGHT = 1
|
| 12 |
+
TURN_LEFT = 2
|
| 13 |
+
TURN_RIGHT = 3
|
| 14 |
+
U_TURN = 4
|
| 15 |
+
|
| 16 |
+
num_actions = 5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AccelerationAction:
|
| 20 |
+
STOP = 0
|
| 21 |
+
KEEP_SPEED = 1
|
| 22 |
+
SPEED_UP = 2
|
| 23 |
+
SLOW_DOWN = 3
|
| 24 |
+
|
| 25 |
+
num_actions = 4
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SafetyAction:
|
| 29 |
+
SAFE = 0
|
| 30 |
+
COLLISION = 1
|
| 31 |
+
num_actions = 2
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cal_polygon_contour(x, y, theta, width, length):
|
| 35 |
+
|
| 36 |
+
left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
|
| 37 |
+
left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
|
| 38 |
+
left_front = np.column_stack((left_front_x, left_front_y))
|
| 39 |
+
|
| 40 |
+
right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
|
| 41 |
+
right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
|
| 42 |
+
right_front = np.column_stack((right_front_x, right_front_y))
|
| 43 |
+
|
| 44 |
+
right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
|
| 45 |
+
right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
|
| 46 |
+
right_back = np.column_stack((right_back_x, right_back_y))
|
| 47 |
+
|
| 48 |
+
left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
|
| 49 |
+
left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
|
| 50 |
+
left_back = np.column_stack((left_back_x, left_back_y))
|
| 51 |
+
|
| 52 |
+
polygon_contour = np.concatenate(
|
| 53 |
+
(left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
return polygon_contour
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def detect_collision(contour_list1, mask1, contour_list2, mask2):
|
| 60 |
+
collision_detected = []
|
| 61 |
+
assert len(contour_list1) == len(contour_list2)
|
| 62 |
+
|
| 63 |
+
for i in range(len(contour_list1)):
|
| 64 |
+
if mask1[i] and mask2[i]:
|
| 65 |
+
poly1 = Polygon(contour_list1[i])
|
| 66 |
+
poly2 = Polygon(contour_list2[i])
|
| 67 |
+
|
| 68 |
+
if poly1.intersects(poly2):
|
| 69 |
+
collision_detected.append(True)
|
| 70 |
+
else:
|
| 71 |
+
collision_detected.append(False)
|
| 72 |
+
else:
|
| 73 |
+
collision_detected.append(False)
|
| 74 |
+
|
| 75 |
+
return collision_detected
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_direction_action_from_trajectory_batch(traj, mask, dt=0.1, ooi=None):
|
| 79 |
+
U_TURN_DEG = 115
|
| 80 |
+
LEFT_TURN_DEG = 25
|
| 81 |
+
RIGHT_TURN_DEG = -25
|
| 82 |
+
STOP_SPEED = 0.06
|
| 83 |
+
|
| 84 |
+
assert traj.ndim == 3
|
| 85 |
+
traj_diff = traj[1:] - traj[:-1]
|
| 86 |
+
mask_diff = mask[1:] & mask[:-1]
|
| 87 |
+
|
| 88 |
+
displacement = np.linalg.norm(traj_diff, axis=-1)
|
| 89 |
+
|
| 90 |
+
mask_diff_stop = mask_diff & (displacement > 0.1)
|
| 91 |
+
|
| 92 |
+
pred_angles = np.arctan2(traj_diff[..., 1], traj_diff[..., 0])
|
| 93 |
+
pred_angles_diff = utils.wrap_to_pi(pred_angles[1:] - pred_angles[:-1])
|
| 94 |
+
|
| 95 |
+
# It's meaning less to compute heading for a stopped vehicle. So mask them out!
|
| 96 |
+
mask_diff_diff = mask_diff_stop[1:] & mask_diff_stop[:-1]
|
| 97 |
+
# Note that we should not wrap to pi here because the sign is important.
|
| 98 |
+
accumulated_heading_change_rad = (pred_angles_diff * mask_diff_diff).sum(axis=0)
|
| 99 |
+
accumulated_heading_change_deg = np.degrees(accumulated_heading_change_rad)
|
| 100 |
+
|
| 101 |
+
# print("accumulated_heading_change_deg: ", list(zip(ooi, accumulated_heading_change_deg)))
|
| 102 |
+
|
| 103 |
+
speed = displacement / dt
|
| 104 |
+
avg_speed = utils.masked_average_numpy(speed, mask_diff, dim=0)
|
| 105 |
+
|
| 106 |
+
actions = np.zeros(accumulated_heading_change_deg.shape, dtype=int)
|
| 107 |
+
actions.fill(TurnAction.KEEP_STRAIGHT)
|
| 108 |
+
actions[accumulated_heading_change_deg > LEFT_TURN_DEG] = TurnAction.TURN_LEFT
|
| 109 |
+
actions[accumulated_heading_change_deg < RIGHT_TURN_DEG] = TurnAction.TURN_RIGHT
|
| 110 |
+
actions[accumulated_heading_change_deg > U_TURN_DEG] = TurnAction.U_TURN
|
| 111 |
+
actions[accumulated_heading_change_deg < -U_TURN_DEG] = TurnAction.U_TURN
|
| 112 |
+
actions[avg_speed < STOP_SPEED] = TurnAction.STOP
|
| 113 |
+
return actions
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_acce_action_from_trajectory_batch(batch_trajs, mask, ooi=None, dt=0.1):
|
| 117 |
+
|
| 118 |
+
SPEEDUP_ACCEL = 0.3
|
| 119 |
+
SPEEDDOWN_ACCEL = -0.3
|
| 120 |
+
STOP_SPEED = 0.06
|
| 121 |
+
|
| 122 |
+
traj_diff = batch_trajs[1:] - batch_trajs[:-1] # (T, A, 2)
|
| 123 |
+
mask_diff = mask[1:] & mask[:-1] # (T, A)
|
| 124 |
+
|
| 125 |
+
speed = np.linalg.norm(traj_diff, axis=-1) / dt # (T, A)
|
| 126 |
+
|
| 127 |
+
speed_change = speed[1:] - speed[:-1]
|
| 128 |
+
mask_diff_diff = mask_diff[1:] & mask_diff[:-1]
|
| 129 |
+
|
| 130 |
+
absolute_avg_speed = utils.masked_average_numpy(speed, mask_diff, dim=0)
|
| 131 |
+
|
| 132 |
+
accumulated_speed_change = (speed_change * mask_diff_diff).sum(0)
|
| 133 |
+
|
| 134 |
+
init_speed_ind = mask_diff.argmax(axis=0)
|
| 135 |
+
init_speed = np.take_along_axis(speed, init_speed_ind[None, :], axis=0)[0]
|
| 136 |
+
|
| 137 |
+
speed_change_ratio = accumulated_speed_change / np.maximum(init_speed, STOP_SPEED)
|
| 138 |
+
|
| 139 |
+
# print("speed_change_ratio: ", list(zip(ooi, speed_change_ratio)))
|
| 140 |
+
|
| 141 |
+
actions = np.zeros(speed_change_ratio.shape, dtype=int)
|
| 142 |
+
|
| 143 |
+
actions.fill(AccelerationAction.KEEP_SPEED)
|
| 144 |
+
actions[speed_change_ratio > SPEEDUP_ACCEL] = AccelerationAction.SPEED_UP
|
| 145 |
+
actions[speed_change_ratio < SPEEDDOWN_ACCEL] = AccelerationAction.SLOW_DOWN
|
| 146 |
+
actions[absolute_avg_speed <= STOP_SPEED] = AccelerationAction.STOP # if stop
|
| 147 |
+
|
| 148 |
+
return actions
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_safety_action_from_sdc_adv(data_dict, adv_id, sdc_id):
|
| 152 |
+
|
| 153 |
+
contours = []
|
| 154 |
+
for agent_id in [adv_id, sdc_id]:
|
| 155 |
+
traj = data_dict["decoder/agent_position"][:91, agent_id, :] # (91, 3)
|
| 156 |
+
length = data_dict["decoder/agent_shape"][:91, agent_id, 0]
|
| 157 |
+
width = data_dict["decoder/agent_shape"][:91, agent_id, 1]
|
| 158 |
+
theta = data_dict['decoder/agent_heading'][:91, agent_id] # (91, ) # in pi
|
| 159 |
+
mask = data_dict['decoder/agent_valid_mask'][:91, agent_id] # (91,)
|
| 160 |
+
|
| 161 |
+
poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length)
|
| 162 |
+
contours.append(poly)
|
| 163 |
+
|
| 164 |
+
sdc_mask = data_dict['decoder/agent_valid_mask'][:, sdc_id] # (91,)
|
| 165 |
+
adv_mask = data_dict['decoder/agent_valid_mask'][:, adv_id]
|
| 166 |
+
adv_contour = contours[0]
|
| 167 |
+
sdc_contour = contours[1]
|
| 168 |
+
|
| 169 |
+
collision_detected = detect_collision(adv_contour, adv_mask, sdc_contour, sdc_mask)
|
| 170 |
+
|
| 171 |
+
# instead of loading a dict which saves all collision scenario, we could simply detect all agents' potential collision
|
| 172 |
+
return collision_detected
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_safety_action_from_trajectory_batch(data_dict, track_agent_indicies):
|
| 176 |
+
|
| 177 |
+
safety_actions = np.zeros((track_agent_indicies.shape[0], ), dtype=int) # plus sdc
|
| 178 |
+
|
| 179 |
+
contours = []
|
| 180 |
+
for agent1_id in track_agent_indicies:
|
| 181 |
+
traj = data_dict["decoder/agent_position"][:, agent1_id, :] # (91, 3)
|
| 182 |
+
length = data_dict["decoder/agent_shape"][:, agent1_id, 0]
|
| 183 |
+
width = data_dict["decoder/agent_shape"][:, agent1_id, 1]
|
| 184 |
+
theta = data_dict['decoder/agent_heading'][:, agent1_id] # (91, ) # in pi
|
| 185 |
+
mask = data_dict['decoder/agent_valid_mask'][:, agent1_id] # (91,)
|
| 186 |
+
poly = cal_polygon_contour(traj[:, 0], traj[:, 1], theta, width, length)
|
| 187 |
+
contours.append(poly)
|
| 188 |
+
|
| 189 |
+
for i in range(track_agent_indicies.shape[0] - 1):
|
| 190 |
+
for j in range(i + 1, track_agent_indicies.shape[0]):
|
| 191 |
+
mask_1 = data_dict['decoder/agent_valid_mask'][:, track_agent_indicies[i]] # (91,)
|
| 192 |
+
mask_2 = data_dict['decoder/agent_valid_mask'][:, track_agent_indicies[j]]
|
| 193 |
+
collision_detected = detect_collision(contours[i], mask_1, contours[j], mask_2)
|
| 194 |
+
|
| 195 |
+
if any(collision_detected):
|
| 196 |
+
# print(f"Collision between {i} and {j} happen at step: {np.array(collision_detected).nonzero()}")
|
| 197 |
+
safety_actions[i] = 1 # Label collisions for OOIs now. Later we will build a larger dict.
|
| 198 |
+
safety_actions[j] = 1
|
| 199 |
+
|
| 200 |
+
# instead of loading a dict which saves all collision scenario, we could simply detect all agents' potential collision
|
| 201 |
+
return safety_actions
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def prepare_action_label(*, data_dict, dt, mask_probability, config):
|
| 205 |
+
"""
|
| 206 |
+
mask_probability: the probability of masking the label. Should be around 0.05 or 0.1. Can't be too high.
|
| 207 |
+
"""
|
| 208 |
+
ooi_ind = data_dict["decoder/labeled_agent_id"]
|
| 209 |
+
ooi_pos = utils.extract_data_by_agent_indices(data_dict["decoder/agent_position"], ooi_ind, agent_dim=1)[..., :2]
|
| 210 |
+
ooi_valid = utils.extract_data_by_agent_indices(
|
| 211 |
+
data_dict["decoder/agent_valid_mask"], ooi_ind, agent_dim=1
|
| 212 |
+
) # (T, A)
|
| 213 |
+
|
| 214 |
+
# TODO: hardcoded here for now and we assume you can access GT trajectory. This won't work with test dataset.
|
| 215 |
+
assert ooi_pos.shape[0] == 91
|
| 216 |
+
assert ooi_valid.shape[0] == 91
|
| 217 |
+
|
| 218 |
+
# get the degree, acceleration, speed
|
| 219 |
+
turn_actions = get_direction_action_from_trajectory_batch(traj=ooi_pos, mask=ooi_valid, dt=dt, ooi=ooi_ind)
|
| 220 |
+
acce_actions = get_acce_action_from_trajectory_batch(ooi_pos, ooi_valid, dt=dt, ooi=ooi_ind)
|
| 221 |
+
|
| 222 |
+
# Rescatter labels to decoder-agent indices
|
| 223 |
+
assert config.TRAINING.PREDICT_ALL_AGENTS
|
| 224 |
+
B = data_dict["decoder/agent_valid_mask"].shape[1]
|
| 225 |
+
|
| 226 |
+
full_turn_actions = np.full((B, ), -1, dtype=int)
|
| 227 |
+
full_acce_actions = np.full((B, ), -1, dtype=int)
|
| 228 |
+
|
| 229 |
+
label_mask = np.random.binomial(1, mask_probability, size=len(ooi_ind))
|
| 230 |
+
label_invalid_mask = label_mask == 1
|
| 231 |
+
|
| 232 |
+
turn_actions[label_invalid_mask] = -1
|
| 233 |
+
acce_actions[label_invalid_mask] = -1
|
| 234 |
+
|
| 235 |
+
full_turn_actions[ooi_ind] = turn_actions
|
| 236 |
+
full_acce_actions[ooi_ind] = acce_actions
|
| 237 |
+
|
| 238 |
+
data_dict["decoder/label_turning"] = full_turn_actions
|
| 239 |
+
data_dict["decoder/label_acceleration"] = full_acce_actions
|
| 240 |
+
|
| 241 |
+
return data_dict
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def prepare_safety_label(*, data_dict, dt, mask_probability, config):
|
| 245 |
+
ooi_ind = data_dict["decoder/labeled_agent_id"]
|
| 246 |
+
|
| 247 |
+
ooi_pos = utils.extract_data_by_agent_indices(data_dict["decoder/agent_position"], ooi_ind, agent_dim=1)[..., :2]
|
| 248 |
+
ooi_valid = utils.extract_data_by_agent_indices(
|
| 249 |
+
data_dict["decoder/agent_valid_mask"], ooi_ind, agent_dim=1
|
| 250 |
+
) # (T, A)
|
| 251 |
+
|
| 252 |
+
# TODO: hardcoded here for now and we assume you can access GT trajectory. This won't work with test dataset.
|
| 253 |
+
assert ooi_pos.shape[0] == 91
|
| 254 |
+
assert ooi_valid.shape[0] == 91
|
| 255 |
+
|
| 256 |
+
safety_actions = get_safety_action_from_trajectory_batch(data_dict, ooi_ind)
|
| 257 |
+
|
| 258 |
+
# Rescatter labels to decoder-agent indices
|
| 259 |
+
assert config.TRAINING.PREDICT_ALL_AGENTS
|
| 260 |
+
num_modeled_agents = data_dict["decoder/agent_valid_mask"].shape[1]
|
| 261 |
+
|
| 262 |
+
full_safety_actions = np.full((num_modeled_agents, ), -1, dtype=int)
|
| 263 |
+
|
| 264 |
+
label_mask = np.random.binomial(1, mask_probability, size=len(ooi_ind))
|
| 265 |
+
label_invalid_mask = label_mask == 1
|
| 266 |
+
|
| 267 |
+
label_invalid_mask[safety_actions == 1] = False # We don't mask collision labels
|
| 268 |
+
|
| 269 |
+
safety_actions[label_invalid_mask] = -1
|
| 270 |
+
|
| 271 |
+
full_safety_actions[ooi_ind] = safety_actions
|
| 272 |
+
|
| 273 |
+
data_dict["decoder/label_safety"] = full_safety_actions
|
| 274 |
+
|
| 275 |
+
return data_dict
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if __name__ == '__main__':
|
| 279 |
+
scenario_dir = "/Users/claire_liu/validation_interactive_0/cat_reconstructed/sd_reconstructed_v0_ScenarioMap-21.pkl"
|
| 280 |
+
cat_dir = "/Users/claire_liu/validation_interactive_0/save.pkl"
|
| 281 |
+
|
| 282 |
+
import pickle
|
| 283 |
+
|
| 284 |
+
with open(scenario_dir, 'rb') as f:
|
| 285 |
+
scenario_data = pickle.load(f)
|
| 286 |
+
f.close()
|
| 287 |
+
|
| 288 |
+
with open(cat_dir, 'rb') as ff:
|
| 289 |
+
cat_dict = pickle.load(ff)
|
| 290 |
+
ff.close()
|
| 291 |
+
|
| 292 |
+
batch_labels = get_3d_action_label(scenario_data, cat_dict)
|
| 293 |
+
print(batch_labels)
|
scenestreamer/dataset/preprocessor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scenestreamer/dataset/scenarionet_utils.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from scenestreamer.utils import wrap_to_pi, rotate
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def overwrite_gt_to_pred_field(data_dict):
|
| 7 |
+
import copy
|
| 8 |
+
new_data_dict = copy.deepcopy(data_dict)
|
| 9 |
+
T, N, _ = data_dict["decoder/agent_position"].shape
|
| 10 |
+
|
| 11 |
+
new_data_dict["decoder/reconstructed_position"] = np.zeros((96, N, 2)).astype(np.float32)
|
| 12 |
+
new_data_dict["decoder/reconstructed_valid_mask"] = np.zeros((
|
| 13 |
+
96,
|
| 14 |
+
N,
|
| 15 |
+
)).astype(bool)
|
| 16 |
+
new_data_dict["decoder/reconstructed_heading"] = np.zeros((
|
| 17 |
+
96,
|
| 18 |
+
N,
|
| 19 |
+
)).astype(np.float32)
|
| 20 |
+
new_data_dict["decoder/reconstructed_velocity"] = np.zeros((96, N, 2)).astype(np.float32)
|
| 21 |
+
|
| 22 |
+
for id in range(N): # overwrite all agents
|
| 23 |
+
traj = new_data_dict["decoder/agent_position"][:91, id, :2].astype(np.float32)
|
| 24 |
+
traj_mask = new_data_dict["decoder/agent_valid_mask"][:91, id].astype(bool)
|
| 25 |
+
theta = new_data_dict['decoder/agent_heading'][:91, id].astype(np.float32)
|
| 26 |
+
vel = new_data_dict['decoder/agent_velocity'][:91, id].astype(np.float32)
|
| 27 |
+
|
| 28 |
+
new_data_dict["decoder/reconstructed_position"][:91, id, :2] = traj
|
| 29 |
+
# new_data_dict["decoder/reconstructed_position"][:91, id, 2] = 0.0
|
| 30 |
+
new_data_dict["decoder/reconstructed_valid_mask"][:91, id] = traj_mask
|
| 31 |
+
# print(traj_mask)
|
| 32 |
+
new_data_dict["decoder/reconstructed_heading"][:91, id] = theta
|
| 33 |
+
new_data_dict["decoder/reconstructed_velocity"][:91, id] = vel
|
| 34 |
+
|
| 35 |
+
return new_data_dict
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_new_adv(data_dict):
|
| 39 |
+
ego_id = data_dict["decoder/sdc_index"]
|
| 40 |
+
|
| 41 |
+
ego_traj = data_dict["decoder/agent_position"][:, ego_id]
|
| 42 |
+
ego_heading = data_dict["decoder/agent_heading"][:, ego_id]
|
| 43 |
+
ego_velocity = data_dict["decoder/agent_velocity"][:, ego_id]
|
| 44 |
+
ego_shape = data_dict["decoder/agent_shape"][:, ego_id]
|
| 45 |
+
ego_mask = data_dict["decoder/agent_valid_mask"][:, ego_id]
|
| 46 |
+
|
| 47 |
+
last_valid_step = np.where(ego_mask)[0][-1]
|
| 48 |
+
|
| 49 |
+
# Create a new ADV at the final step.
|
| 50 |
+
|
| 51 |
+
adv_mask = np.zeros_like(ego_mask)
|
| 52 |
+
adv_mask[:last_valid_step + 1] = True
|
| 53 |
+
|
| 54 |
+
adv_traj = np.zeros_like(ego_traj)
|
| 55 |
+
adv_heading = np.zeros_like(ego_heading)
|
| 56 |
+
adv_velocity = np.zeros_like(ego_velocity)
|
| 57 |
+
adv_shape = np.zeros_like(ego_shape)
|
| 58 |
+
|
| 59 |
+
# Copy the final pos/head/vel/shape of ego
|
| 60 |
+
adv_traj[last_valid_step] = ego_traj[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=3)
|
| 61 |
+
adv_heading[last_valid_step] = ego_heading[last_valid_step] + np.random.normal(loc=0.0, scale=0.1, size=1)
|
| 62 |
+
adv_velocity[last_valid_step] = ego_velocity[last_valid_step] + np.random.normal(loc=0.0, scale=0.5, size=2)
|
| 63 |
+
|
| 64 |
+
for i in range(data_dict["decoder/agent_shape"].shape[0]):
|
| 65 |
+
adv_shape[i] = ego_shape[last_valid_step]
|
| 66 |
+
|
| 67 |
+
# Insert data back:
|
| 68 |
+
data_dict["decoder/agent_position"] = np.concatenate(
|
| 69 |
+
[data_dict["decoder/agent_position"], adv_traj[:, None]], axis=1
|
| 70 |
+
)
|
| 71 |
+
data_dict["decoder/agent_heading"] = np.concatenate(
|
| 72 |
+
[data_dict["decoder/agent_heading"], adv_heading[:, None]], axis=1
|
| 73 |
+
)
|
| 74 |
+
data_dict["decoder/agent_velocity"] = np.concatenate(
|
| 75 |
+
[data_dict["decoder/agent_velocity"], adv_velocity[:, None]], axis=1
|
| 76 |
+
)
|
| 77 |
+
# data_dict["decoder/agent_shape"] = np.concatenate([data_dict["decoder/agent_shape"], adv_shape[:, None]], axis=1)
|
| 78 |
+
|
| 79 |
+
data_dict["decoder/agent_shape"] = np.concatenate([data_dict["decoder/agent_shape"], adv_shape[:, None]], axis=1)
|
| 80 |
+
|
| 81 |
+
data_dict["decoder/agent_valid_mask"] = np.concatenate(
|
| 82 |
+
[data_dict["decoder/agent_valid_mask"], adv_mask[:, None]], axis=1
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
data_dict["decoder/current_agent_shape"] = np.concatenate(
|
| 86 |
+
[data_dict["decoder/current_agent_shape"], data_dict["decoder/current_agent_shape"][ego_id:ego_id + 1]], axis=0
|
| 87 |
+
)
|
| 88 |
+
data_dict["decoder/agent_type"] = np.concatenate(
|
| 89 |
+
[data_dict["decoder/agent_type"], data_dict["decoder/agent_type"][ego_id:ego_id + 1]], axis=0
|
| 90 |
+
)
|
| 91 |
+
data_dict["decoder/agent_id"] = np.concatenate(
|
| 92 |
+
[data_dict["decoder/agent_id"], [len(data_dict["decoder/agent_id"])]], axis=0
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Add ADV into OOI:
|
| 96 |
+
data_dict["decoder/object_of_interest_id"] = np.concatenate(
|
| 97 |
+
[data_dict["decoder/object_of_interest_id"], [len(data_dict["decoder/agent_id"]) - 1]], axis=0
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Deal with some thing for forward prediction:
|
| 101 |
+
data_dict["decoder/current_agent_valid_mask"] = np.concatenate(
|
| 102 |
+
[data_dict["decoder/current_agent_valid_mask"], [1]], axis=0
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
print("====================================")
|
| 106 |
+
print(
|
| 107 |
+
"The new ADV is created at the final step {}, it's ID is: {}".format(
|
| 108 |
+
last_valid_step,
|
| 109 |
+
len(data_dict["decoder/agent_id"]) - 1
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
print("====================================")
|
| 113 |
+
|
| 114 |
+
return data_dict
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def overwrite_to_scenario_description(output_dict_mode, original_SD, ooi=None, adv_id=None):
|
| 118 |
+
# overwrite original SD with all predicted ooi trajectories included
|
| 119 |
+
# import pdb; pdb.set_trace()
|
| 120 |
+
if not ooi:
|
| 121 |
+
ooi = output_dict_mode['decoder/agent_id'] # overwrite all agents
|
| 122 |
+
sdc_track_name = original_SD['metadata']['sdc_id']
|
| 123 |
+
adv_track_name = str(output_dict_mode['decoder/track_name'][int(adv_id)].item())
|
| 124 |
+
|
| 125 |
+
for id in ooi:
|
| 126 |
+
agent_track_name = str(output_dict_mode['decoder/track_name'][id].item())
|
| 127 |
+
|
| 128 |
+
# begin to overwrite original scenario_data
|
| 129 |
+
agent_traj = output_dict_mode["decoder/agent_position"][:91, id, ]
|
| 130 |
+
agent_heading = output_dict_mode["decoder/agent_heading"][:91, id]
|
| 131 |
+
agent_vel = output_dict_mode["decoder/agent_velocity"][:91, id]
|
| 132 |
+
agent_traj_mask = output_dict_mode["decoder/agent_valid_mask"][:91, id]
|
| 133 |
+
|
| 134 |
+
# modify adv info
|
| 135 |
+
# agent_z = original_SD['tracks'][agent_track_name]['state']['position'][10, 2] # fill the z-axis
|
| 136 |
+
# agent_traj_z = np.full((91, 1), agent_z)
|
| 137 |
+
# agent_new_traj = np.concatenate([agent_traj, agent_traj_z], axis=1)
|
| 138 |
+
# print("new_traj:", agent_new_traj.shape)
|
| 139 |
+
original_SD['tracks'][agent_track_name]['state']['position'] = agent_traj
|
| 140 |
+
original_SD['tracks'][agent_track_name]['state']['velocity'] = agent_vel
|
| 141 |
+
original_SD['tracks'][agent_track_name]['state']['heading'] = agent_heading
|
| 142 |
+
original_SD['tracks'][agent_track_name]['state']['valid'] = agent_traj_mask
|
| 143 |
+
|
| 144 |
+
length = original_SD['tracks'][agent_track_name]['state']['length'][10]
|
| 145 |
+
width = original_SD['tracks'][agent_track_name]['state']['width'][10]
|
| 146 |
+
height = original_SD['tracks'][agent_track_name]['state']['height'][10]
|
| 147 |
+
original_SD['tracks'][agent_track_name]['state']['length'] = np.full((91, ), length)
|
| 148 |
+
original_SD['tracks'][agent_track_name]['state']['width'] = np.full((91, ), width)
|
| 149 |
+
original_SD['tracks'][agent_track_name]['state']['height'] = np.full((91, ), height)
|
| 150 |
+
|
| 151 |
+
original_SD['metadata']['selected_adv_id'] = adv_track_name
|
| 152 |
+
|
| 153 |
+
return original_SD
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def overwrite_to_scenario_description_new_agent(output_dict_mode, original_SD, ooi=None):
|
| 157 |
+
# overwrite original SD with all predicted ooi trajectories included
|
| 158 |
+
ooi = output_dict_mode['decoder/agent_id'] # overwrite all agents
|
| 159 |
+
|
| 160 |
+
adv_track_name = 'new_adv_agent'
|
| 161 |
+
original_SD['tracks'][adv_track_name] = {'state': {}, 'type': 'VEHICLE', 'metadata': {}}
|
| 162 |
+
sdc_track_name = original_SD['metadata']['sdc_id']
|
| 163 |
+
|
| 164 |
+
for id in ooi:
|
| 165 |
+
if id == ooi[-1]:
|
| 166 |
+
agent_track_name = 'new_adv_agent'
|
| 167 |
+
else:
|
| 168 |
+
agent_track_name = str(output_dict_mode['decoder/track_name'][id].item())
|
| 169 |
+
|
| 170 |
+
# begin to overwrite original scenario_data
|
| 171 |
+
agent_traj = output_dict_mode["decoder/agent_position"][:, id, ]
|
| 172 |
+
agent_heading = output_dict_mode["decoder/agent_heading"][:, id]
|
| 173 |
+
agent_vel = output_dict_mode["decoder/agent_velocity"][:, id]
|
| 174 |
+
agent_traj_mask = output_dict_mode["decoder/agent_valid_mask"][:, id]
|
| 175 |
+
|
| 176 |
+
# modify adv info
|
| 177 |
+
# agent_z = original_SD['tracks'][agent_track_name]['state']['position'][10, 2] # fill the z-axis
|
| 178 |
+
# agent_traj_z = np.full((91, 1), agent_z)
|
| 179 |
+
# agent_new_traj = np.concatenate([agent_traj, agent_traj_z], axis=1)
|
| 180 |
+
# print("new_traj:", agent_new_traj.shape)
|
| 181 |
+
original_SD['tracks'][agent_track_name]['state']['position'] = agent_traj
|
| 182 |
+
|
| 183 |
+
original_SD['tracks'][agent_track_name]['state']['velocity'] = agent_vel
|
| 184 |
+
original_SD['tracks'][agent_track_name]['state']['heading'] = agent_heading
|
| 185 |
+
original_SD['tracks'][agent_track_name]['state']['valid'] = agent_traj_mask
|
| 186 |
+
|
| 187 |
+
length = original_SD['tracks'][sdc_track_name]['state']['length'][10]
|
| 188 |
+
width = original_SD['tracks'][sdc_track_name]['state']['width'][10]
|
| 189 |
+
height = original_SD['tracks'][sdc_track_name]['state']['height'][10]
|
| 190 |
+
original_SD['tracks'][agent_track_name]['state']['length'] = np.full((91, ), length)
|
| 191 |
+
original_SD['tracks'][agent_track_name]['state']['width'] = np.full((91, ), width)
|
| 192 |
+
original_SD['tracks'][agent_track_name]['state']['height'] = np.full((91, ), height)
|
| 193 |
+
|
| 194 |
+
original_SD['tracks'][adv_track_name]['metadata']['dataset'] = 'waymo'
|
| 195 |
+
original_SD['tracks'][adv_track_name]['metadata']['object_id'] = 'new_adv_agent'
|
| 196 |
+
original_SD['tracks'][adv_track_name]['metadata']['track_length'] = 91
|
| 197 |
+
original_SD['tracks'][adv_track_name]['metadata']['type'] = 'VEHICLE'
|
| 198 |
+
original_SD['metadata']['new_adv_id'] = 'new_adv_agent'
|
| 199 |
+
original_SD['metadata']['objects_of_interest'].append('new_adv_agent')
|
| 200 |
+
tracks_length = len(list(original_SD['tracks'].keys()))
|
| 201 |
+
original_SD['metadata']['tracks_to_predict']['new_adv_agent'] = {
|
| 202 |
+
'difficulty': 0,
|
| 203 |
+
'object_type': 'VEHICLE',
|
| 204 |
+
'track_id': 'new_adv_agent',
|
| 205 |
+
'track_index': tracks_length - 1
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
return original_SD
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def transform_to_global_coordinate(data_dict):
|
| 212 |
+
map_center = data_dict["metadata/map_center"].reshape(-1, 1, 3) # (1,1,3)
|
| 213 |
+
assert "decoder/agent_position" in data_dict, "Have you set EVALUATION.PREDICT_ALL_AGENTS to False?"
|
| 214 |
+
T, N, _ = data_dict["decoder/agent_position"].shape
|
| 215 |
+
assert data_dict["decoder/agent_position"].ndim == 3
|
| 216 |
+
data_dict["decoder/agent_position"] += map_center
|
| 217 |
+
|
| 218 |
+
return data_dict
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _overwrite_datadict_all_agents(data_dict):
|
| 222 |
+
import copy
|
| 223 |
+
new_data_dict = copy.deepcopy(data_dict)
|
| 224 |
+
|
| 225 |
+
T, N, _ = data_dict["decoder/reconstructed_position"].shape
|
| 226 |
+
|
| 227 |
+
for id in range(N): # overwrite all agents
|
| 228 |
+
traj = data_dict["decoder/reconstructed_position"][:91, id, ]
|
| 229 |
+
traj_mask = data_dict["decoder/reconstructed_valid_mask"][:91, id]
|
| 230 |
+
theta = data_dict['decoder/reconstructed_heading'][:91, id]
|
| 231 |
+
vel = data_dict['decoder/reconstructed_velocity'][:91, id]
|
| 232 |
+
|
| 233 |
+
new_data_dict["decoder/agent_position"][:, id, :2] = traj
|
| 234 |
+
new_data_dict["decoder/agent_position"][:, id, 2] = 0.0
|
| 235 |
+
new_data_dict["decoder/agent_valid_mask"][:, id] = traj_mask
|
| 236 |
+
new_data_dict["decoder/agent_heading"][:, id] = theta
|
| 237 |
+
new_data_dict["decoder/agent_velocity"][:, id] = vel
|
| 238 |
+
|
| 239 |
+
return new_data_dict
|